From b5648b9fa8dbebd0d57a71eb6d840d98051acafc Mon Sep 17 00:00:00 2001 From: Tanguy MARCHAND <84329436+tanguy-marchand@users.noreply.github.com> Date: Tue, 3 Dec 2024 16:30:31 +0100 Subject: [PATCH] first commit --- .ci/self_hosted_conda_creation.sh | 25 + .ci/self_hosted_conda_removal.sh | 8 + .github/ISSUE_TEMPLATE/bug_report.md | 28 + .github/ISSUE_TEMPLATE/feature_request.md | 20 + .github/PULL_REQUEST_TEMPLATE.md | 16 + .github/workflows/pr_push_validation.yml | 68 + .gitignore | 184 ++ .pre-commit-config.yaml | 33 + LICENSE | 21 + README.md | 115 + fedpydeseq2/__init__.py | 1 + fedpydeseq2/core/deseq2_core/__init__.py | 7 + .../build_design_matrix/__init__.py | 5 + .../build_design_matrix.py | 144 ++ .../build_design_matrix/substeps.py | 248 ++ .../compute_cook_distance/__init__.py | 5 + .../compute_cook_distance.py | 131 + .../compute_cook_distance/substeps.py | 116 + .../compute_size_factors/__init__.py | 5 + .../compute_size_factors.py | 104 + .../compute_size_factors/substeps.py | 108 + .../core/deseq2_core/deseq2_full_pipe.py | 189 ++ .../deseq2_lfc_dispersions/__init__.py | 3 + .../compute_MAP_dispersions/__init__.py | 5 + .../compute_MAP_dispersions.py | 101 + .../compute_MAP_dispersions/substeps.py | 56 + .../compute_dispersion_prior/__init__.py | 3 + .../compute_dispersion_prior.py | 100 + .../compute_dispersion_prior/substeps.py | 279 +++ .../compute_dispersion_prior/utils.py | 25 + .../compute_genewise_dispersions/__init__.py | 5 + .../compute_MoM_dispersions/__init__.py | 5 + .../compute_MoM_dispersions.py | 124 + .../compute_rough_dispersions.py | 124 + .../compute_MoM_dispersions/substeps.py | 301 +++ .../compute_genewise_dispersions.py | 192 ++ .../get_num_replicates/__init__.py | 5 + .../get_num_replicates/get_num_replicates.py | 83 + .../get_num_replicates/substeps.py | 109 + .../compute_genewise_dispersions/substeps.py | 110 + .../compute_lfc/__init__.py | 5 + .../compute_lfc/compute_lfc.py | 172 ++ .../compute_lfc/substeps.py | 400 +++ .../compute_lfc/utils.py | 58 + .../deseq2_lfc_dispersions.py | 173 ++ .../core/deseq2_core/deseq2_stats/__init__.py | 3 + .../deseq2_stats/compute_padj/__init__.py | 5 + .../deseq2_stats/compute_padj/compute_padj.py | 95 + .../deseq2_stats/compute_padj/substeps.py | 148 ++ .../deseq2_stats/cooks_filtering/__init__.py | 5 + .../cooks_filtering/cooks_filtering.py | 187 ++ .../deseq2_stats/cooks_filtering/substeps.py | 533 ++++ .../deseq2_core/deseq2_stats/deseq2_stats.py | 102 + .../deseq2_stats/wald_tests/__init__.py | 1 + .../deseq2_stats/wald_tests/substeps.py | 161 ++ .../deseq2_stats/wald_tests/wald_tests.py | 71 + .../deseq2_core/replace_outliers/__init__.py | 3 + .../replace_outliers/replace_outliers.py | 164 ++ .../deseq2_core/replace_outliers/substeps.py | 312 +++ .../replace_refitted_values/__init__.py | 3 + .../replace_refitted_values.py | 99 + .../core/deseq2_core/save_pipeline_results.py | 155 ++ fedpydeseq2/core/deseq2_strategy.py | 381 +++ fedpydeseq2/core/fed_algorithms/__init__.py | 17 + .../compute_trimmed_mean/__init__.py | 3 + .../compute_trimmed_mean.py | 194 ++ .../compute_trimmed_mean/substeps.py | 869 +++++++ .../compute_trimmed_mean/utils.py | 52 + .../dispersions_grid_search/__init__.py | 3 + .../dispersions_grid_search.py | 136 ++ .../dispersions_grid_search/substeps.py | 254 ++ .../core/fed_algorithms/fed_PQN/__init__.py | 3 + .../core/fed_algorithms/fed_PQN/fed_PQN.py | 133 + .../core/fed_algorithms/fed_PQN/substeps.py | 742 ++++++ .../core/fed_algorithms/fed_PQN/utils.py | 257 ++ 
.../core/fed_algorithms/fed_irls/__init__.py | 3 + .../core/fed_algorithms/fed_irls/fed_irls.py | 197 ++ .../core/fed_algorithms/fed_irls/substeps.py | 381 +++ .../core/fed_algorithms/fed_irls/utils.py | 81 + fedpydeseq2/core/utils/__init__.py | 14 + fedpydeseq2/core/utils/aggregation.py | 42 + fedpydeseq2/core/utils/compute_lfc_utils.py | 78 + fedpydeseq2/core/utils/design_matrix.py | 140 ++ fedpydeseq2/core/utils/layers/__init__.py | 6 + .../utils/layers/build_layers/__init__.py | 31 + .../core/utils/layers/build_layers/cooks.py | 137 ++ .../layers/build_layers/fit_lin_mu_hat.py | 77 + .../layers/build_layers/hat_diagonals.py | 216 ++ .../core/utils/layers/build_layers/mu_hat.py | 103 + .../utils/layers/build_layers/mu_layer.py | 157 ++ .../layers/build_layers/normed_counts.py | 56 + .../core/utils/layers/build_layers/sqerror.py | 91 + .../core/utils/layers/build_layers/y_hat.py | 60 + .../core/utils/layers/build_refit_adata.py | 107 + fedpydeseq2/core/utils/layers/cooks_layer.py | 328 +++ fedpydeseq2/core/utils/layers/joblib_utils.py | 32 + .../layers/reconstruct_adatas_decorator.py | 250 ++ fedpydeseq2/core/utils/layers/utils.py | 214 ++ fedpydeseq2/core/utils/logging/__init__.py | 3 + .../core/utils/logging/default_config.ini | 21 + .../core/utils/logging/logging_decorators.py | 215 ++ fedpydeseq2/core/utils/mle.py | 305 +++ fedpydeseq2/core/utils/negative_binomial.py | 149 ++ fedpydeseq2/core/utils/pass_on_results.py | 39 + fedpydeseq2/core/utils/pipe_steps.py | 140 ++ fedpydeseq2/core/utils/stat_utils.py | 211 ++ fedpydeseq2/fedpydeseq2_pipeline.py | 149 ++ fedpydeseq2/substra_utils/__init__.py | 0 .../credentials/credentials-template.yaml | 9 + .../dataset-datasamples-keys-template.yaml | 6 + .../substra_utils/federated_experiment.py | 490 ++++ fedpydeseq2/substra_utils/utils.py | 186 ++ poetry.lock | 2164 +++++++++++++++++ pyproject.toml | 99 + tests/__init__.py | 0 tests/conftest.py | 64 + tests/deseq2_end_to_end/__init__.py | 0 tests/deseq2_end_to_end/test_deseq2_pipe.py | 316 +++ .../test_deseq2_pipe_local.py | 351 +++ .../test_deseq2_pipe_utils.py | 109 + tests/paths_default.json | 1 + tests/tcga_testing_pipe.py | 181 ++ tests/unit_tests/__init__.py | 0 tests/unit_tests/deseq2_core/__init__.py | 1 + .../deseq2_lfc_dispersions/__init__.py | 1 + .../compute_genewise_dispersions/__init__.py | 15 + .../test_MoM_dispersions.py | 454 ++++ .../test_compute_mu_hat.py | 310 +++ .../test_dispersions_from_mu_hat.py | 861 +++++++ ...test_genewise_dispersions_single_factor.py | 748 ++++++ .../test_get_num_replicates.py | 477 ++++ .../utils_genewise_dispersions.py | 118 + .../compute_lfc/__init__.py | 1 + .../compute_lfc/compute_lfc_test_pipe.py | 569 +++++ .../compute_lfc/compute_lfc_tester.py | 425 ++++ .../compute_lfc/substeps.py | 230 ++ .../compute_lfc/test_compute_lfc.py | 303 +++ .../test_MAP_dispersions.py | 792 ++++++ .../test_MAP_dispersions_filtering.py | 469 ++++ .../test_trend_curve.py | 417 ++++ .../deseq2_core/deseq2_stats/__init__.py | 1 + .../deseq2_stats/test_compute_padj.py | 552 +++++ .../deseq2_stats/test_cooks_filtering.py | 544 +++++ .../deseq2_stats/test_wald_tests.py | 535 ++++ .../deseq2_core/test_cooks_distances.py | 538 ++++ .../deseq2_core/test_design_matrices.py | 514 ++++ .../deseq2_core/test_refit_cooks.py | 757 ++++++ .../deseq2_core/test_replace_cooks.py | 576 +++++ .../deseq2_core/test_save_pipeline_results.py | 429 ++++ .../deseq2_core/test_size_factors.py | 466 ++++ tests/unit_tests/fed_algorithms/__init__.py | 1 + .../fed_algorithms/fed_IRLS/__init__.py 
| 0 .../fed_IRLS/fed_IRLS_tester.py | 370 +++ .../fed_algorithms/fed_IRLS/irls_test_pipe.py | 157 ++ .../fed_IRLS/test_IRLS_utils.py | 156 ++ .../fed_algorithms/fed_IRLS/test_irls.py | 173 ++ .../fed_algorithms/fed_IRLS_PQN_tester.py | 150 ++ .../fed_prox_quasi_newton/__init__.py | 1 + .../fed_prox_quasi_newton/fed_pqn_tester.py | 380 +++ .../fed_prox_quasi_newton/pqn_test_pipe.py | 312 +++ .../test_fed_prox_newton_utils.py | 184 ++ .../fed_prox_quasi_newton/test_pqn.py | 199 ++ .../trimmed_mean_strategy/__init__.py | 0 .../trimmed_mean_strategy/opener/__init__.py | 0 .../opener/description.md | 1 + .../trimmed_mean_strategy/opener/opener.py | 52 + .../test_trimmed_mean_strategy.py | 229 ++ .../trimmed_mean_strategy.py | 196 ++ tests/unit_tests/layers/__init__.py | 1 + tests/unit_tests/layers/layers_tester.py | 528 ++++ tests/unit_tests/layers/opener/__init__.py | 1 + tests/unit_tests/layers/opener/description.md | 1 + tests/unit_tests/layers/opener/opener.py | 47 + .../layers/test_reconstruct_layers.py | 274 +++ tests/unit_tests/layers/test_utils.py | 135 + tests/unit_tests/layers/utils.py | 112 + .../unit_tests/unit_test_helpers/__init__.py | 0 tests/unit_tests/unit_test_helpers/levels.py | 72 + .../pass_on_first_shared_state.py | 31 + .../unit_test_helpers/set_local_reference.py | 112 + .../unit_test_helpers/unit_tester.py | 258 ++ tests/unit_tests/utils/test_mle.py | 30 + 182 files changed, 32385 insertions(+) create mode 100755 .ci/self_hosted_conda_creation.sh create mode 100755 .ci/self_hosted_conda_removal.sh create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 .github/workflows/pr_push_validation.yml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 LICENSE create mode 100644 README.md create mode 100644 fedpydeseq2/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/build_design_matrix/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/build_design_matrix/build_design_matrix.py create mode 100644 fedpydeseq2/core/deseq2_core/build_design_matrix/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/compute_cook_distance/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/compute_cook_distance/compute_cook_distance.py create mode 100644 fedpydeseq2/core/deseq2_core/compute_cook_distance/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/compute_size_factors/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/compute_size_factors/compute_size_factors.py create mode 100644 fedpydeseq2/core/deseq2_core/compute_size_factors/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_full_pipe.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/compute_MAP_dispersions.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/compute_dispersion_prior.py create mode 100644 
fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/utils.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_MoM_dispersions.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_rough_dispersions.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_genewise_dispersions.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/get_num_replicates.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/utils.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/deseq2_lfc_dispersions.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_stats/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/compute_padj.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/cooks_filtering.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_stats/deseq2_stats.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/wald_tests.py create mode 100644 fedpydeseq2/core/deseq2_core/replace_outliers/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/replace_outliers/replace_outliers.py create mode 100644 fedpydeseq2/core/deseq2_core/replace_outliers/substeps.py create mode 100644 fedpydeseq2/core/deseq2_core/replace_refitted_values/__init__.py create mode 100644 fedpydeseq2/core/deseq2_core/replace_refitted_values/replace_refitted_values.py create mode 100644 fedpydeseq2/core/deseq2_core/save_pipeline_results.py create mode 100644 fedpydeseq2/core/deseq2_strategy.py create mode 100644 
fedpydeseq2/core/fed_algorithms/__init__.py create mode 100644 fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/__init__.py create mode 100644 fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/compute_trimmed_mean.py create mode 100644 fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/substeps.py create mode 100644 fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/utils.py create mode 100644 fedpydeseq2/core/fed_algorithms/dispersions_grid_search/__init__.py create mode 100644 fedpydeseq2/core/fed_algorithms/dispersions_grid_search/dispersions_grid_search.py create mode 100644 fedpydeseq2/core/fed_algorithms/dispersions_grid_search/substeps.py create mode 100644 fedpydeseq2/core/fed_algorithms/fed_PQN/__init__.py create mode 100644 fedpydeseq2/core/fed_algorithms/fed_PQN/fed_PQN.py create mode 100644 fedpydeseq2/core/fed_algorithms/fed_PQN/substeps.py create mode 100644 fedpydeseq2/core/fed_algorithms/fed_PQN/utils.py create mode 100644 fedpydeseq2/core/fed_algorithms/fed_irls/__init__.py create mode 100644 fedpydeseq2/core/fed_algorithms/fed_irls/fed_irls.py create mode 100644 fedpydeseq2/core/fed_algorithms/fed_irls/substeps.py create mode 100644 fedpydeseq2/core/fed_algorithms/fed_irls/utils.py create mode 100644 fedpydeseq2/core/utils/__init__.py create mode 100644 fedpydeseq2/core/utils/aggregation.py create mode 100644 fedpydeseq2/core/utils/compute_lfc_utils.py create mode 100644 fedpydeseq2/core/utils/design_matrix.py create mode 100644 fedpydeseq2/core/utils/layers/__init__.py create mode 100644 fedpydeseq2/core/utils/layers/build_layers/__init__.py create mode 100644 fedpydeseq2/core/utils/layers/build_layers/cooks.py create mode 100644 fedpydeseq2/core/utils/layers/build_layers/fit_lin_mu_hat.py create mode 100644 fedpydeseq2/core/utils/layers/build_layers/hat_diagonals.py create mode 100644 fedpydeseq2/core/utils/layers/build_layers/mu_hat.py create mode 100644 fedpydeseq2/core/utils/layers/build_layers/mu_layer.py create mode 100644 fedpydeseq2/core/utils/layers/build_layers/normed_counts.py create mode 100644 fedpydeseq2/core/utils/layers/build_layers/sqerror.py create mode 100644 fedpydeseq2/core/utils/layers/build_layers/y_hat.py create mode 100644 fedpydeseq2/core/utils/layers/build_refit_adata.py create mode 100644 fedpydeseq2/core/utils/layers/cooks_layer.py create mode 100644 fedpydeseq2/core/utils/layers/joblib_utils.py create mode 100644 fedpydeseq2/core/utils/layers/reconstruct_adatas_decorator.py create mode 100644 fedpydeseq2/core/utils/layers/utils.py create mode 100644 fedpydeseq2/core/utils/logging/__init__.py create mode 100644 fedpydeseq2/core/utils/logging/default_config.ini create mode 100644 fedpydeseq2/core/utils/logging/logging_decorators.py create mode 100644 fedpydeseq2/core/utils/mle.py create mode 100644 fedpydeseq2/core/utils/negative_binomial.py create mode 100644 fedpydeseq2/core/utils/pass_on_results.py create mode 100644 fedpydeseq2/core/utils/pipe_steps.py create mode 100644 fedpydeseq2/core/utils/stat_utils.py create mode 100644 fedpydeseq2/fedpydeseq2_pipeline.py create mode 100644 fedpydeseq2/substra_utils/__init__.py create mode 100644 fedpydeseq2/substra_utils/credentials/credentials-template.yaml create mode 100644 fedpydeseq2/substra_utils/credentials/dataset-datasamples-keys-template.yaml create mode 100644 fedpydeseq2/substra_utils/federated_experiment.py create mode 100644 fedpydeseq2/substra_utils/utils.py create mode 100644 poetry.lock create mode 100644 pyproject.toml create mode 100644 tests/__init__.py create mode 
100644 tests/conftest.py create mode 100644 tests/deseq2_end_to_end/__init__.py create mode 100644 tests/deseq2_end_to_end/test_deseq2_pipe.py create mode 100644 tests/deseq2_end_to_end/test_deseq2_pipe_local.py create mode 100644 tests/deseq2_end_to_end/test_deseq2_pipe_utils.py create mode 100644 tests/paths_default.json create mode 100644 tests/tcga_testing_pipe.py create mode 100644 tests/unit_tests/__init__.py create mode 100644 tests/unit_tests/deseq2_core/__init__.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/__init__.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/__init__.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_MoM_dispersions.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_compute_mu_hat.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_dispersions_from_mu_hat.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_genewise_dispersions_single_factor.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_get_num_replicates.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/utils_genewise_dispersions.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/__init__.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc_test_pipe.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc_tester.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/test_compute_lfc.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_MAP_dispersions.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_MAP_dispersions_filtering.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_trend_curve.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_stats/__init__.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_stats/test_compute_padj.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_stats/test_cooks_filtering.py create mode 100644 tests/unit_tests/deseq2_core/deseq2_stats/test_wald_tests.py create mode 100644 tests/unit_tests/deseq2_core/test_cooks_distances.py create mode 100644 tests/unit_tests/deseq2_core/test_design_matrices.py create mode 100644 tests/unit_tests/deseq2_core/test_refit_cooks.py create mode 100644 tests/unit_tests/deseq2_core/test_replace_cooks.py create mode 100644 tests/unit_tests/deseq2_core/test_save_pipeline_results.py create mode 100644 tests/unit_tests/deseq2_core/test_size_factors.py create mode 100644 tests/unit_tests/fed_algorithms/__init__.py create mode 100644 tests/unit_tests/fed_algorithms/fed_IRLS/__init__.py create mode 100644 tests/unit_tests/fed_algorithms/fed_IRLS/fed_IRLS_tester.py create mode 100644 tests/unit_tests/fed_algorithms/fed_IRLS/irls_test_pipe.py create mode 100644 tests/unit_tests/fed_algorithms/fed_IRLS/test_IRLS_utils.py create mode 100644 tests/unit_tests/fed_algorithms/fed_IRLS/test_irls.py create mode 100644 tests/unit_tests/fed_algorithms/fed_IRLS_PQN_tester.py create mode 100644 
tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/__init__.py create mode 100644 tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/fed_pqn_tester.py create mode 100644 tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/pqn_test_pipe.py create mode 100644 tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/test_fed_prox_newton_utils.py create mode 100644 tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/test_pqn.py create mode 100644 tests/unit_tests/fed_algorithms/trimmed_mean_strategy/__init__.py create mode 100644 tests/unit_tests/fed_algorithms/trimmed_mean_strategy/opener/__init__.py create mode 100644 tests/unit_tests/fed_algorithms/trimmed_mean_strategy/opener/description.md create mode 100644 tests/unit_tests/fed_algorithms/trimmed_mean_strategy/opener/opener.py create mode 100644 tests/unit_tests/fed_algorithms/trimmed_mean_strategy/test_trimmed_mean_strategy.py create mode 100644 tests/unit_tests/fed_algorithms/trimmed_mean_strategy/trimmed_mean_strategy.py create mode 100644 tests/unit_tests/layers/__init__.py create mode 100644 tests/unit_tests/layers/layers_tester.py create mode 100644 tests/unit_tests/layers/opener/__init__.py create mode 100644 tests/unit_tests/layers/opener/description.md create mode 100644 tests/unit_tests/layers/opener/opener.py create mode 100644 tests/unit_tests/layers/test_reconstruct_layers.py create mode 100644 tests/unit_tests/layers/test_utils.py create mode 100644 tests/unit_tests/layers/utils.py create mode 100644 tests/unit_tests/unit_test_helpers/__init__.py create mode 100644 tests/unit_tests/unit_test_helpers/levels.py create mode 100644 tests/unit_tests/unit_test_helpers/pass_on_first_shared_state.py create mode 100644 tests/unit_tests/unit_test_helpers/set_local_reference.py create mode 100644 tests/unit_tests/unit_test_helpers/unit_tester.py create mode 100644 tests/unit_tests/utils/test_mle.py diff --git a/.ci/self_hosted_conda_creation.sh b/.ci/self_hosted_conda_creation.sh new file mode 100755 index 0000000..37a57a3 --- /dev/null +++ b/.ci/self_hosted_conda_creation.sh @@ -0,0 +1,25 @@ +#!/bin/sh + +# Check that the env folder exists or create it +if [ -d ~/envs ] +then + echo "Found existing envs folder" +else + echo "Did not find envs folder, creating" + mkdir ~/envs +fi + +# Check if conda environment exists. if it does, remove it. +if [ -d "~/envs/fedomics_python_$1" ] +then + echo "Found existing fedomics conda environment, removing" + conda env remove --prefix "~/envs/fedomics_python_$1" -y +fi +conda init bash +. ~/.bashrc +# +echo "Creating environment" +yes | conda create --prefix "~/envs/fedomics_python_$1" python="$1" +echo "Created env fedomics_python_$1" +eval "$(conda shell.bash hook)" +conda activate "~/envs/fedomics_python_$1" diff --git a/.ci/self_hosted_conda_removal.sh b/.ci/self_hosted_conda_removal.sh new file mode 100755 index 0000000..675caf6 --- /dev/null +++ b/.ci/self_hosted_conda_removal.sh @@ -0,0 +1,8 @@ +#!/bin/sh + +# Check if conda environment exists. if it does, remove it. 
+if [ -d "~/envs/fedomics_python_$1" ] +then + echo "Found existing fedomics conda environment, removing" + conda env remove --prefix "~/envs/fedomics_python_$1" -y +fi diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..5009781 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,28 @@ +--- +name: Bug report +about: Create a report to help us improve +title: "[BUG] " +labels: bug +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Provide snippets of code and steps on how to reproduce the behavior. +Please also specify the version you are using. + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +**Desktop (please complete the following information):** + - OS: [e.g. iOS] + - Version [e.g. 0.02] + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..bbcbbe7 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..6222fe2 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,16 @@ + + +#### Reference Issue or PRs + + + +#### What does your PR implement? Be specific. 
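A note on the `.ci/self_hosted_conda_creation.sh` and `.ci/self_hosted_conda_removal.sh` scripts added above: both take the target Python version as their first positional argument (`$1`) and derive the environment prefix `~/envs/fedomics_python_$1` from it. The sketch below shows one plausible way a self-hosted runner job could call them; the exact invocation is an assumption and is not specified in this patch. (Note also that `~` is not expanded inside double quotes, so the `[ -d "~/envs/..." ]` existence checks in these scripts may not detect an existing environment as intended.)

```bash
# Hypothetical usage on a self-hosted runner (not part of the committed workflow).
# Create the conda environment for the requested Python version; the script also
# attempts to activate it (the activation only persists inside the script itself).
bash .ci/self_hosted_conda_creation.sh 3.11   # env prefix: ~/envs/fedomics_python_3.11

# ... run the test suite inside the environment here ...

# Tear the environment down once the job is finished.
bash .ci/self_hosted_conda_removal.sh 3.11
```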
diff --git a/.github/workflows/pr_push_validation.yml b/.github/workflows/pr_push_validation.yml new file mode 100644 index 0000000..69d03a4 --- /dev/null +++ b/.github/workflows/pr_push_validation.yml @@ -0,0 +1,68 @@ +name: Python dev + +on: + pull_request: + push: + branches: + - main + +jobs: + testing: + runs-on: ubuntu-latest + strategy: + matrix: + python: ["3.10", "3.11", "3.12"] + name: Testing Python ${{ matrix.python }} + steps: + - name: Checkout fedpydeseq2 + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + + - name: Install Poetry + run: | + python --version + pip install poetry==1.8.2 + + - name: Install dependencies + run: | + which python + python --version + poetry install --with testing + + - name: Download data + run: | + mkdir -p /opt/conda + wget https://repo.anaconda.com/miniconda/Miniconda3-py39_24.5.0-0-Linux-x86_64.sh -O /opt/conda/miniconda.sh + bash /opt/conda/miniconda.sh -b -p /opt/miniconda + poetry run fedpydeseq2-download-data --only_luad --raw_data_output_path /home/runner/work/fedpydeseq2/fedpydeseq2/data/raw --conda_activate_path /opt/miniconda/bin/activate + ls /home/runner/work/fedpydeseq2/fedpydeseq2/data/raw + ls /home/runner/work/fedpydeseq2/fedpydeseq2/data/raw/tcga + - name: Testing + run: | + poetry run pytest -v tests -m "not self_hosted_slow and not self_hosted_fast and not local and not docker" + + linting: + runs-on: ubuntu-latest + name: Test Linting + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install Poetry + run: pip install poetry==1.8.2 + + - name: Install dependencies + run: | + which python + python --version + poetry install --with linting + + - name: Pre-commit checks + run: | + poetry run pre-commit run --all-files --show-diff-on-failure diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..89b17de --- /dev/null +++ b/.gitignore @@ -0,0 +1,184 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+.idea/
+
+# VSCode
+.vscode/
+
+# DS_Store
+.DS_Store
+
+# Data
+tests/deseq2_end_to_end/local-worker/*
+local-worker/*
+fedpydeseq2/substra_utils/credentials/*
+!fedpydeseq2/substra_utils/credentials/credentials-template.yaml
+!fedpydeseq2/substra_utils/credentials/dataset-datasamples-keys-template.yaml
+tests/local-worker/*
+tmp_substrafl*
+
+/experiments/credentials/
+/paper_experiments/gsea/gsea.sh
+*.lock
+!poetry.lock
+
+tests/datasets_parent_dir.txt
+
+data/raw/*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..223e5b9
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,33 @@
+default_language_version:
+  python: python3.11
+repos:
+  - repo: https://github.com/sirosen/check-jsonschema
+    rev: 0.27.0
+    hooks:
+      - id: check-github-actions
+      - id: check-github-workflows
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: trailing-whitespace
+        name: Trim trailing whitespace
+      - id: end-of-file-fixer
+        name: Fix end of files
+        exclude: \.ipynb$
+  - repo: https://github.com/psf/black
+    rev: 23.11.0
+    hooks:
+      - id: black
+        additional_dependencies: ["click==8.0.4"]
+        args: # arguments to configure black
+          - --line-length=88
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.1.5
+    hooks:
+      - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.7.0
+    hooks:
+      - id: mypy
+        exclude: ^(tests/|docs/source/conf.py|datasets/)
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..b48d852
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Owkin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e4ea853
--- /dev/null
+++ b/README.md
@@ -0,0 +1,115 @@
+# FedPyDESeq2: Putting the Fed in the Py
+
+## Setup
+
+### Package installation through PyPI
+
+You can install the package from PyPI using the following command:
+
+```bash
+pip install fedpydeseq2
+```
+
+
+### Package installation for developers
+
+#### 0 - Clone the repository
+
+Start by cloning this repository:
+
+```bash
+git clone git@github.com:owkin/fedpydeseq2.git
+```
+
+#### 1 - Create a conda environment with python 3.9+
+
+```
+conda create -n fedpydeseq2 python=3.11 # or a compatible version
+conda activate fedpydeseq2
+```
+
+#### 2 - Install `poetry`
+
+Run
+
+```
+conda install pip
+pip install poetry==1.8.2
+```
+
+and test the installation with `poetry --version`.
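Before moving on to installing the project dependencies, an optional sanity check (a suggestion, not a step required by this README) is to confirm that both tools resolve to the conda environment created above:

```bash
# Both commands should point inside the fedpydeseq2 conda environment.
which python && python --version   # e.g. Python 3.11.x, matching the environment you created
which poetry && poetry --version   # e.g. Poetry (version 1.8.2), as pinned above
```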
+
+
+
+#### 3 - Install the package and its dependencies using `poetry`
+
+`cd` to the root of the repository and run
+
+```
+poetry install --with linting,testing
+```
+
+#### 4 - Download the data to run the tests on
+
+To download the data, `cd` to the root of the repository and run this command:
+
+```bash
+fedpydeseq2-download-data --raw_data_output_path data/raw
+```
+
+This creates a `data/raw` subdirectory containing all the necessary data. If you
+want to store the raw data in a different location, run this command instead:
+
+```bash
+fedpydeseq2-download-data --raw_data_output_path MY_RAW_PATH
+```
+
+Then create a file in the `tests` directory named `paths.json` containing:
+- A `raw_data` field with the path to the raw data `MY_RAW_PATH`;
+- An optional `assets_tcga` field with the path to the directory containing the
+`opener.py` file and its description (present by default in the
+fedpydeseq2_datasets module, so there is no need to specify this unless you need
+to modify the opener);
+- An optional `processed_data` field with the path to the directory where you
+want to save processed data. Note that this is only used if you want to run
+tests locally without reprocessing the data during each test session (tests
+marked with the `local` marker).
+
+
+#### 5 - Install `pre-commit` hooks
+
+Still in the root of the repository, run
+
+`pre-commit install`
+
+You are now ready to contribute.
+
+## CI on a self-hosted runner
+Tests are run using a self-hosted runner. To add a self-hosted runner,
+instantiate the machine you want to use as a runner, go to the repository
+settings, then to the `Actions` tab, and click on `Add runner`. Follow the
+instructions to install the runner on that machine.
+
+Make sure to give the runner the label "fedpydeseq2-self-hosted" so that
+the CI workflow can find it.
+
+### Docker CI
+The Docker mode is only tested manually. To test it, first run `poetry build`
+to create a wheel in the `dist` folder. Then launch the following in a tmux
+session:
+```
+pytest -m "docker" -s
+```
+The `-s` option prints all logs/outputs continuously; otherwise they only appear
+once the test is done. Since the test takes a long time, it is better to print
+them continuously.
+
+## Running on a real Substra environment
+
+### Running the CP
+To run a compute plan on an environment with the Substra front-end, you first
+need to generate a token on each of the Substra nodes. Then duplicate
+[credentials-template.yaml](fedpydeseq2/substra_utils/credentials/credentials-template.yaml)
+into a new file
+[credentials.yaml](fedpydeseq2/substra_utils/credentials/credentials.yaml) and
+fill in the tokens. You should not need to rebuild the wheel manually by running
+`poetry build`, as the script will try to do it for you, but watch out for
+related error messages when executing the file.
diff --git a/fedpydeseq2/__init__.py b/fedpydeseq2/__init__.py
new file mode 100644
index 0000000..2c5167d
--- /dev/null
+++ b/fedpydeseq2/__init__.py
@@ -0,0 +1 @@
+from fedpydeseq2.core.deseq2_strategy import DESeq2Strategy
diff --git a/fedpydeseq2/core/deseq2_core/__init__.py b/fedpydeseq2/core/deseq2_core/__init__.py
new file mode 100644
index 0000000..0e7bb89
--- /dev/null
+++ b/fedpydeseq2/core/deseq2_core/__init__.py
@@ -0,0 +1,7 @@
+"""Module containing the core of the DESeq2 pipeline.
+
+It contains all the Mixin classes corresponding to the steps of the pipeline.
+The main class defined in this module is the DESeq2FullPipe class. +""" + +from fedpydeseq2.core.deseq2_core.deseq2_full_pipe import DESeq2FullPipe diff --git a/fedpydeseq2/core/deseq2_core/build_design_matrix/__init__.py b/fedpydeseq2/core/deseq2_core/build_design_matrix/__init__.py new file mode 100644 index 0000000..39145e7 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/build_design_matrix/__init__.py @@ -0,0 +1,5 @@ +"""Module to regroup the steps to build the design matrix for DESeq2.""" + +from fedpydeseq2.core.deseq2_core.build_design_matrix.build_design_matrix import ( + BuildDesignMatrix, +) diff --git a/fedpydeseq2/core/deseq2_core/build_design_matrix/build_design_matrix.py b/fedpydeseq2/core/deseq2_core/build_design_matrix/build_design_matrix.py new file mode 100644 index 0000000..79fb386 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/build_design_matrix/build_design_matrix.py @@ -0,0 +1,144 @@ +from fedpydeseq2.core.deseq2_core.build_design_matrix.substeps import ( + AggMergeDesignColumnsBuildContrast, +) +from fedpydeseq2.core.deseq2_core.build_design_matrix.substeps import ( + AggMergeDesignLevels, +) +from fedpydeseq2.core.deseq2_core.build_design_matrix.substeps import LocGetLocalFactors +from fedpydeseq2.core.deseq2_core.build_design_matrix.substeps import ( + LocOderDesignComputeLogMean, +) +from fedpydeseq2.core.deseq2_core.build_design_matrix.substeps import LocSetLocalDesign +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class BuildDesignMatrix( + AggMergeDesignColumnsBuildContrast, + AggMergeDesignLevels, + LocGetLocalFactors, + LocSetLocalDesign, + LocOderDesignComputeLogMean, +): + """Mixin class to implement the computation of the design matrix. + + Methods + ------- + build_design_matrix + The method to build the design matrix, that must be used in the main + pipeline. + + check_design_matrix + The method to check the design matrix, that must be used in the main + pipeline while we are testing. + """ + + def build_design_matrix( + self, train_data_nodes, aggregation_node, local_states, round_idx, clean_models + ): + """Build the design matrix. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: + The current round + + clean_models: bool + Whether to clean the models after the computation. + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + shared_states: dict + Shared states containing the necessary local information to start + the next step of the pipeline, which is computing the size factors. + They contain a "log_means" key and a "n_samples" key. 
+ + round_idx: int + The updated round + + """ + # ---- For each design factor, get the list of each center's levels ---- # + if len(local_states) == 0: + # In that case, there is no reference dds, and this is the first step of + # The pipeline + input_local_states = None + else: + # In this case, there was already a step before, and we need to propagate + # the local states + input_local_states = local_states + + local_states, shared_states, round_idx = local_step( + local_method=self.get_local_factors, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=input_local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Computing local design factor levels", + clean_models=clean_models, + ) + + # ---- For each design factor, merge the list of unique levels ---- # + + design_levels_aggregated_state, round_idx = aggregation_step( + aggregation_method=self.merge_design_levels, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Merging design levels", + clean_models=clean_models, + ) + + # ---- Initialize design matrices in each center ---- # + + local_states, shared_states, round_idx = local_step( + local_method=self.set_local_design, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=design_levels_aggregated_state, + aggregation_id=aggregation_node.organization_id, + description="Setting local design matrices", + clean_models=clean_models, + ) + + # ---- Merge design columns ---- # + + design_columns_aggregated_state, round_idx = aggregation_step( + aggregation_method=self.merge_design_columns_and_build_contrast, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Merge local design matrix columns", + clean_models=clean_models, + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.order_design_cols_compute_local_log_mean, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + round_idx=round_idx, + input_shared_state=design_columns_aggregated_state, + aggregation_id=aggregation_node.organization_id, + description="Computing local log means", + clean_models=clean_models, + ) + + return local_states, shared_states, round_idx diff --git a/fedpydeseq2/core/deseq2_core/build_design_matrix/substeps.py b/fedpydeseq2/core/deseq2_core/build_design_matrix/substeps.py new file mode 100644 index 0000000..88b6ac9 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/build_design_matrix/substeps.py @@ -0,0 +1,248 @@ +"""Module containing the substeps for the computation of design matrices. + +This module contains all these substeps as mixin classes. 
+""" + + +import anndata as ad +import numpy as np +import pandas as pd +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils import build_contrast +from fedpydeseq2.core.utils import build_design_matrix +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class AggMergeDesignColumnsBuildContrast: + """Mixin to merge the columns of the design matrices and build contrast.""" + + design_factors: list[str] + continuous_factors: list[str] | None + contrast: list[str] | None + + @remote + @log_remote + def merge_design_columns_and_build_contrast(self, shared_states): + """Merge the columns of the design matrices and build constrasts. + + Parameters + ---------- + shared_states : list + List of results (dictionaries of design columns) from training nodes. + + Returns + ------- + dict + Shared state containing: + - merged_columns: the names of the columns that the local design matrices + should have. + - contrast: the contrast (in a list of strings form) to be used for the + DESeq2 model. + """ + merged_columns = pd.Index([]) + + for state in shared_states: + merged_columns = merged_columns.union(state["design_columns"]) + + # We now also have everything to compute the contrasts + contrast = build_contrast( + self.design_factors, + merged_columns, + self.continuous_factors, + self.contrast, + ) + + return {"merged_columns": merged_columns, "contrast": contrast} + + +class AggMergeDesignLevels: + """Mixin to merge the levels of the design factors.""" + + categorical_factors: list[str] + + @remote + @log_remote + def merge_design_levels(self, shared_states): + """Merge the levels of the design factors. + + Parameters + ---------- + shared_states : list + List of results (dictionaries of local_levels) from training nodes. + + Returns + ------- + dict + Dictionary of unique levels for each factor. + """ + # merge levels + merged_levels = {factor: set() for factor in self.categorical_factors} + for factor in self.categorical_factors: + for state in shared_states: + merged_levels[factor] = set(state["local_levels"][factor]).union( + merged_levels[factor] + ) + + return { + "merged_levels": { + factor: np.array(list(levels)) + for factor, levels in merged_levels.items() + } + } + + +class LocGetLocalFactors: + """Mixin to get the list of unique levels for each categorical design factor.""" + + categorical_factors: list[str] + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_local_factors( + self, data_from_opener, shared_state=None + ): # pylint: disable=unused-argument + """Get the list of unique levels for each categorical design factor. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Copied in local anndata objects. + + shared_state : None, optional + Not used. + + Returns + ------- + dict + A dictionary of unique local levels for each factor. 
+ """ + self.local_adata = data_from_opener.copy() + return { + "local_levels": { + factor: self.local_adata.obs[factor].unique() + for factor in self.categorical_factors + } + } + + +class LocSetLocalDesign: + """Mixin to set the design matrices in centers.""" + + local_adata: ad.AnnData + design_factors: list[str] + continuous_factors: list[str] | None + ref_levels: dict[str, str] | None + + @remote_data + @log_remote_data + @reconstruct_adatas + def set_local_design( + self, + data_from_opener, + shared_state, + ): + # pylint: disable=unused-argument + """ + Set the design matrices in centers. + + Returns their columns in order to harmonize them. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Shared state with a "design_columns" key containing a dictionary with, for + each design factor, the names of its unique levels. + + Returns + ------- + dict + Local design columns. + """ + self.local_adata.obsm["design_matrix"] = build_design_matrix( + metadata=self.local_adata.obs, + design_factors=self.design_factors, + continuous_factors=self.continuous_factors, + levels=shared_state["merged_levels"], + ref_levels=self.ref_levels, + ) + return {"design_columns": self.local_adata.obsm["design_matrix"].columns} + + +class LocOderDesignComputeLogMean: + """Mixin to order design cols and compute the local log mean. + + Attributes + ---------- + local_adata : ad.AnnData + The local AnnData. + + Methods + ------- + order_design_cols_compute_local_log_mean + Order design columns and compute the local log mean. + + """ + + local_adata: ad.AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def order_design_cols_compute_local_log_mean( + self, data_from_opener, shared_state=None + ): + """Order design columns and compute the local log mean. + + This function also sets the contrast in the local AnnData, + and saves the number of parameters in the uns field. + + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Shared state with: + - "merged_columns" a set containing the names of columns that the design + matrix should have. + - "contrast" the contrast to be used for the DESeq2 model. + + Returns + ------- + dict + Local mean of logs and number of samples. 
+ """ + #### ----Step 1: Order design columns---- #### + + self.local_adata.uns["contrast"] = shared_state["contrast"] + + for col in shared_state["merged_columns"]: + if col not in self.local_adata.obsm["design_matrix"].columns: + self.local_adata.obsm["design_matrix"][col] = 0 + + # Reorder columns for consistency + self.local_adata.obsm["design_matrix"] = self.local_adata.obsm["design_matrix"][ + shared_state["merged_columns"] + ] + + # Save the number of params in an uns field for easy access + self.local_adata.uns["n_params"] = self.local_adata.obsm["design_matrix"].shape[ + 1 + ] + + #### ----Step 2: Compute local log mean---- #### + + with np.errstate(divide="ignore"): # ignore division by zero warnings + return { + "log_mean": np.log(data_from_opener.X).mean(axis=0), + "n_samples": data_from_opener.n_obs, + } diff --git a/fedpydeseq2/core/deseq2_core/compute_cook_distance/__init__.py b/fedpydeseq2/core/deseq2_core/compute_cook_distance/__init__.py new file mode 100644 index 0000000..68f1a75 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/compute_cook_distance/__init__.py @@ -0,0 +1,5 @@ +"""Pipe step computing the cooks distance.""" + +from fedpydeseq2.core.deseq2_core.compute_cook_distance.compute_cook_distance import ( + ComputeCookDistances, +) diff --git a/fedpydeseq2/core/deseq2_core/compute_cook_distance/compute_cook_distance.py b/fedpydeseq2/core/deseq2_core/compute_cook_distance/compute_cook_distance.py new file mode 100644 index 0000000..061f34b --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/compute_cook_distance/compute_cook_distance.py @@ -0,0 +1,131 @@ +from fedpydeseq2.core.deseq2_core.compute_cook_distance.substeps import ( + AggComputeDispersionForCook, +) +from fedpydeseq2.core.deseq2_core.compute_cook_distance.substeps import ( + LocComputeSqerror, +) +from fedpydeseq2.core.deseq2_core.compute_cook_distance.substeps import ( + LocGetNormedCounts, +) +from fedpydeseq2.core.fed_algorithms import ComputeTrimmedMean +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class ComputeCookDistances( + ComputeTrimmedMean, + LocComputeSqerror, + LocGetNormedCounts, + AggComputeDispersionForCook, +): + """Mixin class to compute Cook's distances. + + Methods + ------- + compute_cook_distance + The method to compute Cook's distances. + + """ + + trimmed_mean_num_iter: int + + def compute_cook_distance( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """ + Compute Cook's distances. + + Parameters + ---------- + train_data_nodes: list + list of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: list[dict] + Local states. Required to propagate intermediate results. + + round_idx: int + Index of the current round. + + clean_models: bool + Whether to clean the models after the computation. + + Returns + ------- + local_states: dict + Local states. The new local state contains Cook's distances. + + dispersion_for_cook_shared_state: dict + Shared state with the dispersion values for Cook's distances, in a + "cooks_dispersions" key. + + round_idx: int + The updated round index. 
+ + """ + local_states, agg_shared_state, round_idx = self.compute_trim_mean( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + layer_used="normed_counts", + mode="cooks", + trim_ratio=None, + n_iter=self.trimmed_mean_num_iter, + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.local_compute_sqerror, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=agg_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Compute local sqerror", + round_idx=round_idx, + clean_models=clean_models, + ) + + local_states, agg_shared_state, round_idx = self.compute_trim_mean( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + layer_used="sqerror", + mode="cooks", + trim_ratio=None, + n_iter=self.trimmed_mean_num_iter, + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.local_get_normed_count_means, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=agg_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Get normed count means", + round_idx=round_idx, + clean_models=clean_models, + ) + + dispersion_for_cook_shared_state, round_idx = aggregation_step( + aggregation_method=self.agg_compute_dispersion_for_cook, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Compute dispersion for Cook distances", + round_idx=round_idx, + clean_models=clean_models, + ) + + return local_states, dispersion_for_cook_shared_state, round_idx diff --git a/fedpydeseq2/core/deseq2_core/compute_cook_distance/substeps.py b/fedpydeseq2/core/deseq2_core/compute_cook_distance/substeps.py new file mode 100644 index 0000000..508118e --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/compute_cook_distance/substeps.py @@ -0,0 +1,116 @@ +import pandas as pd +from anndata import AnnData +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import prepare_cooks_agg +from fedpydeseq2.core.utils.layers import prepare_cooks_local +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.layers.build_layers import set_sqerror_layer +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocComputeSqerror: + """Compute the squared error between the normalized counts and the trimmed mean.""" + + local_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def local_compute_sqerror( + self, + data_from_opener, + shared_state=dict, + ) -> None: + """ + Compute the squared error between the normalized counts and the trimmed mean. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict, optional + Results to save in the local states. 
+ + """ + cell_means = shared_state["trimmed_mean_normed_counts"] + if isinstance(cell_means, pd.DataFrame): + cell_means.index = self.local_adata.var_names + self.local_adata.varm["cell_means"] = cell_means + else: + # In this case, the cell means are not computed per + # level but overall + self.local_adata.varm["cell_means"] = cell_means + set_sqerror_layer(self.local_adata) + + +class LocGetNormedCounts: + """Get the mean of the normalized counts.""" + + local_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + @prepare_cooks_local + def local_get_normed_count_means( + self, + data_from_opener, + shared_state=dict, + ) -> dict: + """ + Send local normed counts means. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict, optional + Dictionary with the following keys: + - varEst: variance estimate for Cook's distance calculation + + Returns + ------- + dict + Because of the decorator, dictionary with the following keys: + - mean_normed_counts: mean of the normalized counts + - n_samples: number of samples + - varEst: variance estimate + + """ + return {} + + +class AggComputeDispersionForCook: + """Compute the dispersion for Cook's distance calculation.""" + + @remote + @log_remote + @prepare_cooks_agg + def agg_compute_dispersion_for_cook( + self, + shared_states: list[dict], + ) -> dict: + """ + Compute the dispersion for Cook's distance calculation. + + Parameters + ---------- + shared_states : list[dict] + list of shared states with the following keys: + - mean_normed_counts: mean of the normalized counts + - n_samples: number of samples + - varEst: variance estimate + + Returns + ------- + dict + Because it is decorated, the dictionary will have the following key: + - cooks_dispersions: dispersion values + + """ + return {} diff --git a/fedpydeseq2/core/deseq2_core/compute_size_factors/__init__.py b/fedpydeseq2/core/deseq2_core/compute_size_factors/__init__.py new file mode 100644 index 0000000..f7c12f4 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/compute_size_factors/__init__.py @@ -0,0 +1,5 @@ +"""Module to implement the computation of size factors.""" + +from fedpydeseq2.core.deseq2_core.compute_size_factors.compute_size_factors import ( + ComputeSizeFactors, +) diff --git a/fedpydeseq2/core/deseq2_core/compute_size_factors/compute_size_factors.py b/fedpydeseq2/core/deseq2_core/compute_size_factors/compute_size_factors.py new file mode 100644 index 0000000..9d29c0f --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/compute_size_factors/compute_size_factors.py @@ -0,0 +1,104 @@ +"""Module containing the steps for the computation of rough dispersions.""" + +from fedpydeseq2.core.deseq2_core.compute_size_factors.substeps import AggLogMeans +from fedpydeseq2.core.deseq2_core.compute_size_factors.substeps import ( + LocSetSizeFactorsComputeGramAndFeatures, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class ComputeSizeFactors( + AggLogMeans, + LocSetSizeFactorsComputeGramAndFeatures, +): + """Mixin class to implement the computation of size factors. + + Methods + ------- + compute_size_factors + The method to compute the size factors, that must be used in the main + pipeline. It sets the size factors in the local AnnData and computes the + Gram matrix and feature vector in order to start the next step, i.e., + the computation of rough dispersions. 
+ + """ + + def compute_size_factors( + self, + train_data_nodes, + aggregation_node, + local_states, + shared_states, + round_idx, + clean_models, + ): + """Compute size factors. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + shared_states: list + Shared states which are the output of the "build_design_matrix" step. + These shared states contain the following fields: + - "log_mean" : the log mean of the gene expressions. + - "n_samples" : the number of samples in each client. + + round_idx: int + Index of the current round. + + clean_models: bool + Whether to clean the models after the computation. + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + shared_states: dict + Shared states which contain the local information necessary to start + running the compute rough dispersions step. These shared states contain + a "local_gram_matrix" and a "local_features" key. + + round_idx: int + The updated round index. + + """ + # ---- Aggregate means of log gene expressions ----# + + log_mean_aggregated_state, round_idx = aggregation_step( + aggregation_method=self.aggregate_log_means, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Aggregating local log means", + clean_models=clean_models, + ) + + # ---- Set local size factors and return next shared states ---- # + + local_states, shared_states, round_idx = local_step( + local_method=self.local_set_size_factors_compute_gram_and_features, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=log_mean_aggregated_state, + aggregation_id=aggregation_node.organization_id, + description=( + "Setting local size factors and " + "computing Gram matrices and feature vectors" + ), + clean_models=clean_models, + ) + + return local_states, shared_states, round_idx diff --git a/fedpydeseq2/core/deseq2_core/compute_size_factors/substeps.py b/fedpydeseq2/core/deseq2_core/compute_size_factors/substeps.py new file mode 100644 index 0000000..56d0dbb --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/compute_size_factors/substeps.py @@ -0,0 +1,108 @@ +"""Module containing the substeps for the computation of size factors.""" + + +import anndata as ad +import numpy as np +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.aggregation import aggregate_means +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class AggLogMeans: + """Mixin to compute the global mean given the local results.""" + + @remote + @log_remote + def aggregate_log_means(self, shared_states): + """Compute the global mean given the local results. + + Parameters + ---------- + shared_states : list + List of results (local_mean, n_samples) from training nodes. + + Returns + ------- + dict + Global mean of log counts, and new all-zero genes if in refit mode. 
+ """ + tot_mean = aggregate_means( + [state["log_mean"] for state in shared_states], + [state["n_samples"] for state in shared_states], + ) + + return {"global_log_mean": tot_mean} + + +class LocSetSizeFactorsComputeGramAndFeatures: + """Mixin to set local size factors and return local Gram matrices and features. + + This Mixin implements the method to perform the transition between the + compute_size_factors and compute_rough_dispersions steps. It sets the size + factors in the local AnnData and computes the Gram matrix and feature vector. + + Methods + ------- + local_set_size_factors_compute_gram_and_features + The method to set the size factors and compute the Gram matrix and feature. + + """ + + local_adata: ad.AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def local_set_size_factors_compute_gram_and_features( + self, + data_from_opener, + shared_state, + ) -> dict: + # pylint: disable=unused-argument + """Set local size factor and compute Gram matrix and feature vector. + + This is a local method, used to fit the rough dispersions. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Shared state containing the "global_log_mean" key. + + Returns + ------- + dict + Local gram matrices and feature vectors to be shared via shared_state to + the aggregation node. + """ + #### ---- Compute size factors ---- #### + + global_log_means = shared_state["global_log_mean"] + # Filter out genes with -∞ log means + filtered_genes = ~np.isinf(global_log_means) + + log_ratios = ( + np.log(self.local_adata.X[:, filtered_genes]) + - global_log_means[filtered_genes] + ) + # Compute sample-wise median of log ratios + log_medians = np.median(log_ratios, axis=1) + # Return raw counts divided by size factors (exponential of log ratios) + # and size factors + self.local_adata.obsm["size_factors"] = np.exp(log_medians) + self.local_adata.layers["normed_counts"] = ( + self.local_adata.X / self.local_adata.obsm["size_factors"][:, None] + ) + + design = self.local_adata.obsm["design_matrix"].values + + return { + "local_gram_matrix": design.T @ design, + "local_features": design.T @ self.local_adata.layers["normed_counts"], + } diff --git a/fedpydeseq2/core/deseq2_core/deseq2_full_pipe.py b/fedpydeseq2/core/deseq2_core/deseq2_full_pipe.py new file mode 100644 index 0000000..8600cda --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_full_pipe.py @@ -0,0 +1,189 @@ +from loguru import logger +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef + +from fedpydeseq2.core.deseq2_core.build_design_matrix import BuildDesignMatrix +from fedpydeseq2.core.deseq2_core.compute_cook_distance import ComputeCookDistances +from fedpydeseq2.core.deseq2_core.compute_size_factors import ComputeSizeFactors +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions import DESeq2LFCDispersions +from fedpydeseq2.core.deseq2_core.deseq2_stats.deseq2_stats import DESeq2Stats +from fedpydeseq2.core.deseq2_core.replace_outliers import ReplaceCooksOutliers +from fedpydeseq2.core.deseq2_core.replace_refitted_values import ReplaceRefittedValues +from fedpydeseq2.core.deseq2_core.save_pipeline_results import SavePipelineResults + + +class DESeq2FullPipe( + BuildDesignMatrix, + ComputeSizeFactors, + DESeq2LFCDispersions, + ComputeCookDistances, + ReplaceCooksOutliers, + ReplaceRefittedValues, + DESeq2Stats, + SavePipelineResults, +): + """A 
Mixin class to run the full DESeq2 pipeline. + + Methods + ------- + run_deseq_pipe + The method to run the full DESeq2 pipeline. + """ + + def run_deseq_pipe( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + local_states: dict[str, LocalStateRef], + round_idx: int = 0, + clean_models: bool = True, + clean_last_model: bool = False, + ): + """Run the DESeq2 pipeline. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + local_states : dict[str, LocalStateRef] + Local states. + round_idx : int + Round index. + clean_models : bool + Whether to clean the models after the computation. (default: ``True``). + Note that as intermediate steps are very memory consuming, it is recommended + to clean the models after each step. + clean_last_model : bool + Whether to clean the last model. (default: ``False``). + """ + #### Build design matrices #### + + logger.info("Building design matrices...") + + local_states, log_mean_shared_states, round_idx = self.build_design_matrix( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + logger.info("Finished building design matrices.") + + #### Compute size factors #### + # Note: in refit mode, this doesn't recompute size factors, + # just the log features + + logger.info("Computing size factors...") + + ( + local_states, + gram_features_shared_states, + round_idx, + ) = self.compute_size_factors( + train_data_nodes, + aggregation_node, + local_states, + shared_states=log_mean_shared_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + logger.info("Finished computing size factors.") + + #### Compute LFC and dispersions #### + + logger.info("Running LFC and dispersions.") + + local_states, round_idx = self.run_deseq2_lfc_dispersions( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + gram_features_shared_states=gram_features_shared_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + logger.info("Finished running LFC and dispersions.") + + logger.info("Computing Cook distances...") + + ( + local_states, + cooks_shared_state, + round_idx, + ) = self.compute_cook_distance( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + logger.info("Finished computing Cook distances.") + + #### Refit cooks if necessary #### + if self.refit_cooks: + logger.info("Refitting Cook outliers...") + ( + local_states, + gram_features_shared_states, + round_idx, + ) = self.replace_outliers( + train_data_nodes, + aggregation_node, + local_states, + cooks_shared_state, + round_idx, + clean_models=clean_models, + ) + + local_states, round_idx = self.run_deseq2_lfc_dispersions( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + gram_features_shared_states=gram_features_shared_states, + round_idx=round_idx, + clean_models=clean_models, + refit_mode=True, + ) + # Replace values in the main ``local_adata`` object + local_states, round_idx = self.replace_refitted_values( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + logger.info("Finished refitting Cook outliers.") + + #### Compute DESeq2 statistics #### + + logger.info("Running DESeq2 statistics.") + + local_states, round_idx = self.run_deseq2_stats( + 
train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + logger.info("Finished running DESeq2 statistics.") + + # Build the results that will be downloaded at the end of the pipeline. + + logger.info("Saving pipeline results.") + self.save_pipeline_results( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + logger.info("Finished saving pipeline results.") diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/__init__.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/__init__.py new file mode 100644 index 0000000..3e369a7 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/__init__.py @@ -0,0 +1,3 @@ +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.deseq2_lfc_dispersions import ( + DESeq2LFCDispersions, +) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/__init__.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/__init__.py new file mode 100644 index 0000000..aa7b9f6 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/__init__.py @@ -0,0 +1,5 @@ +"""Module containing the mixin class to compute MAP dispersions.""" + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_MAP_dispersions.compute_MAP_dispersions import ( # noqa: E501 + ComputeMAPDispersions, +) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/compute_MAP_dispersions.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/compute_MAP_dispersions.py new file mode 100644 index 0000000..fa71251 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/compute_MAP_dispersions.py @@ -0,0 +1,101 @@ +"""Main module to compute dispersions by minimizing the MLE using a grid search.""" + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_MAP_dispersions.substeps import ( # noqa: E501 + LocFilterMAPDispersions, +) +from fedpydeseq2.core.fed_algorithms import ComputeDispersionsGridSearch +from fedpydeseq2.core.utils import local_step + + +class ComputeMAPDispersions( + LocFilterMAPDispersions, + ComputeDispersionsGridSearch, +): + """ + Mixin class to implement the computation of MAP dispersions. + + Methods + ------- + fit_MAP_dispersions + A method to fit the MAP dispersions and filter them. + The filtering is done by removing the dispersions that are too far from the + trend curve. + + """ + + def fit_MAP_dispersions( + self, + train_data_nodes, + aggregation_node, + local_states, + shared_state, + round_idx, + clean_models, + refit_mode: bool = False, + ): + """Fit MAP dispersions, and apply filtering. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + shared_state: dict + Contains the output of the trend fitting, + that is a dictionary with a "fitted_dispersion" field containing + the fitted dispersions from the trend curve, a "prior_disp_var" field + containing the prior variance of the dispersions, and a "_squared_logres" + field containing the squared residuals of the trend fitting. + + round_idx: int + The current round. + + clean_models: bool + Whether to clean the models after the computation. 
+ + refit_mode: bool + Whether to run the pipeline in refit mode, after cooks outliers were + replaced. If True, the pipeline will be run on `refit_adata`s instead of + `local_adata`s. (default: False). + + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: int + The updated round index. + """ + local_states, shared_state, round_idx = self.fit_dispersions( + train_data_nodes, + aggregation_node, + local_states, + shared_state=shared_state, + round_idx=round_idx, + clean_models=clean_models, + fit_mode="MAP", + refit_mode=refit_mode, + ) + + # Filter the MAP dispersions. + local_states, _, round_idx = local_step( + local_method=self.filter_outlier_genes, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=shared_state, + aggregation_id=aggregation_node.organization_id, + description="Filter MAP dispersions.", + round_idx=round_idx, + clean_models=clean_models, + method_params={"refit_mode": refit_mode}, + ) + + return local_states, round_idx diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/substeps.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/substeps.py new file mode 100644 index 0000000..50f9a0d --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/substeps.py @@ -0,0 +1,56 @@ +import numpy as np +from anndata import AnnData +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocFilterMAPDispersions: + """Mixin to filter MAP dispersions and obtain the final dispersion estimates.""" + + local_adata: AnnData + refit_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def filter_outlier_genes( + self, + data_from_opener, + shared_state, + refit_mode: bool = False, + ) -> None: + """Filter out outlier genes. + + Avoids shrinking the dispersions of genes that are too far from the trend curve. + + Parameters + ---------- + data_from_opener : ad.AnnData + Not used. + + shared_state : dict + Contains: + - "MAP_dispersions": MAP dispersions, + + refit_mode : bool + Whether to run the pipeline on `refit_adata`s instead of `local_adata`s. + (default: False). 
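+
+        Notes
+        -----
+        A gene is flagged as an outlier when its gene-wise dispersion lies more
+        than two residual standard deviations (in log space) above the trend,
+        i.e. when ``log(genewise_dispersions) > log(fitted_dispersions)
+        + 2 * sqrt(_squared_logres)``. Outlier genes keep their gene-wise (MLE)
+        dispersions instead of the shrunk MAP values.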
+ """ + if refit_mode: + adata = self.refit_adata + else: + adata = self.local_adata + + adata.varm["MAP_dispersions"] = shared_state["MAP_dispersions"].copy() + + adata.varm["dispersions"] = adata.varm["MAP_dispersions"].copy() + adata.varm["_outlier_genes"] = np.log( + adata.varm["genewise_dispersions"] + ) > np.log(adata.varm["fitted_dispersions"]) + 2 * np.sqrt( + adata.uns["_squared_logres"] + ) + adata.varm["dispersions"][adata.varm["_outlier_genes"]] = adata.varm[ + "genewise_dispersions" + ][adata.varm["_outlier_genes"]] diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/__init__.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/__init__.py new file mode 100644 index 0000000..bd69eb0 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/__init__.py @@ -0,0 +1,3 @@ +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_dispersion_prior.compute_dispersion_prior import ( # noqa: E501 + ComputeDispersionPrior, +) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/compute_dispersion_prior.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/compute_dispersion_prior.py new file mode 100644 index 0000000..bd8deb3 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/compute_dispersion_prior.py @@ -0,0 +1,100 @@ +"""Module containing the steps for fitting the dispersion trend.""" +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_dispersion_prior.substeps import ( # noqa: E501 + AggFitDispersionTrendAndPrior, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_dispersion_prior.substeps import ( # noqa: E501 + LocGetMeanDispersionAndMean, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_dispersion_prior.substeps import ( # noqa: E501 + LocUpdateFittedDispersions, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class ComputeDispersionPrior( + AggFitDispersionTrendAndPrior, + LocGetMeanDispersionAndMean, + LocUpdateFittedDispersions, +): + """Mixin class to implement the fit of the dispersion trend. + + Methods + ------- + compute_dispersion_prior + The method to fit the dispersion trend. + + """ + + def compute_dispersion_prior( + self, + train_data_nodes, + aggregation_node, + local_states, + genewise_dispersions_shared_state, + round_idx, + clean_models, + ): + """Fit the dispersion trend. + + Parameters + ---------- + train_data_nodes: list + list of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + genewise_dispersions_shared_state: dict + Shared state with a "genewise_dispersions" key. + + round_idx: int + Index of the current round. + + clean_models: bool + Whether to clean the models after the computation. + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + dispersion_trend_share_state: dict + Shared states with: + - "fitted_dispersions": the fitted dispersions, + - "prior_disp_var": the prior dispersion variance. + + round_idx: int + The updated round index. 
+ + """ + # --- Return means and dispersions ---# + # TODO : merge this step with the last steps from genewise dispersion + local_states, shared_states, round_idx = local_step( + local_method=self.get_local_mean_and_dispersion, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=genewise_dispersions_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Get local means and dispersions", + clean_models=clean_models, + ) + + # ---- Fit dispersion trend ----# + + dispersion_trend_shared_state, round_idx = aggregation_step( + aggregation_method=self.agg_fit_dispersion_trend_and_prior_dispersion, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Fitting dispersion trend", + clean_models=clean_models, + ) + + return local_states, dispersion_trend_shared_state, round_idx diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/substeps.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/substeps.py new file mode 100644 index 0000000..6854d5b --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/substeps.py @@ -0,0 +1,279 @@ +"""Module containing the substeps for the computation of size factors.""" +import warnings + +import anndata as ad +import numpy as np +import pandas as pd +from pydeseq2.default_inference import DefaultInference +from pydeseq2.utils import mean_absolute_deviation +from scipy.special import polygamma # type: ignore +from scipy.stats import trim_mean # type: ignore +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_dispersion_prior.utils import ( # noqa: E501 + disp_function, +) +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +# TODO : This step could be removed now that genewise dispersions are computed in the +# pipeline. This would save an aggregation -> local node -> aggregation node +# communication. +class LocGetMeanDispersionAndMean: + """Mixin to get the local mean and dispersion.""" + + local_adata: ad.AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_local_mean_and_dispersion( + self, + data_from_opener, + shared_state: dict, + ) -> dict: + # pylint: disable=unused-argument + """Return local gene means and dispersion. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Shared state returned by the last step of gene-wise dispersion computation. + Contains a "genewise_dispersions" key with the gene-wise dispersions. + + Returns + ------- + dict + Local results to be shared via shared_state to the aggregation node. dict + with the following keys: + - mean_normed_counts: np.ndarray[float] of shape (n_genes,) + The mean normed counts. + - n_obs: int + The number of observations. + - non_zero: np.ndarray[bool] of shape (n_genes,) + Mask of the genes with non zero counts. + - genewise_dispersions: np.ndarray[float] of shape (n_genes,) + The genewise dispersions. + - num_vars: int + The number of variables. + + """ + # Save gene-wise dispersions from the previous step. + # Dispersions of all-zero genes should already be NaN. 
+ self.local_adata.varm["genewise_dispersions"] = shared_state[ + "genewise_dispersions" + ] + + # TODO: these could be gathered earlier and sent directly to the aggregation + # node. + return { + "mean_normed_counts": self.local_adata.layers["normed_counts"].mean(0), + "n_obs": self.local_adata.n_obs, + "non_zero": self.local_adata.varm["non_zero"], + "genewise_dispersions": self.local_adata.varm["genewise_dispersions"], + "n_params": self.local_adata.uns["n_params"], + } + + +class AggFitDispersionTrendAndPrior: + """Mixin class to implement the fit of the dispersion trend.""" + + min_disp: float + + @remote + @log_remote + def agg_fit_dispersion_trend_and_prior_dispersion(self, shared_states): + """ + Fit the dispersion trend, and compute the dispersion prior. + + Parameters + ---------- + shared_states : dict + Shared states from the local step with the following keys: + - genewise_dispersions: np.ndarray of shape (n_genes,) + - n_params: int + - non_zero: np.ndarray of shape (n_genes,) + - mean_normed_counts: np.ndarray of shape (n_genes,) + - n_obs: int + + Returns + ------- + dict + dict with the following keys: + - prior_disp_var: float + The prior dispersion variance. + - _squared_logres: float + The squared log-residuals. + - trend_coeffs: np.ndarray of shape (2,) + The coefficients of the parametric dispersion trend. + - fitted_dispersions: np.ndarray of shape (n_genes,) + The fitted dispersions, computed from the dispersion trend. + - disp_function_type: str + The type of dispersion function (parametric or mean). + - mean_disp: float, optional + The mean dispersion (if "mean" fit type). + + """ + genewise_dispersions = shared_states[0]["genewise_dispersions"] + n_params = shared_states[0]["n_params"] + non_zero = shared_states[0]["non_zero"] + n_total_obs = sum([state["n_obs"] for state in shared_states]) + mean_normed_counts = ( + sum( + [ + state["mean_normed_counts"] * state["n_obs"] + for state in shared_states + ] + ) + / n_total_obs + ) + + # Exclude all-zero counts + targets = pd.Series( + genewise_dispersions.copy(), + ) + targets = targets[non_zero] + covariates = pd.Series(1 / mean_normed_counts[non_zero], index=targets.index) + + for gene in targets.index: + if ( + np.isinf(covariates.loc[gene]).any() + or np.isnan(covariates.loc[gene]).any() + ): + targets.drop(labels=[gene], inplace=True) + covariates.drop(labels=[gene], inplace=True) + + # Initialize coefficients + old_coeffs = pd.Series([0.1, 0.1]) + coeffs = pd.Series([1.0, 1.0]) + mean_disp = None + + disp_function_type = "parametric" + while (coeffs > 1e-10).all() and ( + np.log(np.abs(coeffs / old_coeffs)) ** 2 + ).sum() >= 1e-6: + old_coeffs = coeffs + ( + coeffs, + predictions, + converged, + ) = DefaultInference().dispersion_trend_gamma_glm(covariates, targets) + + if not converged or (coeffs <= 1e-10).any(): + warnings.warn( + "The dispersion trend curve fitting did not converge. 
" + "Switching to a mean-based dispersion trend.", + UserWarning, + stacklevel=2, + ) + mean_disp = trim_mean( + genewise_dispersions[genewise_dispersions > 10 * self.min_disp], + proportiontocut=0.001, + ) + disp_function_type = "mean" + + pred_ratios = genewise_dispersions[covariates.index] / predictions + + targets.drop( + targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index, + inplace=True, + ) + covariates.drop( + covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index, + inplace=True, + ) + + fitted_dispersions = np.full_like(genewise_dispersions, np.NaN) + + fitted_dispersions[non_zero] = disp_function( + mean_normed_counts[non_zero], + disp_function_type=disp_function_type, + coeffs=coeffs, + mean_disp=mean_disp, + ) + + disp_residuals = np.log(genewise_dispersions[non_zero]) - np.log( + fitted_dispersions[non_zero] + ) + + # Compute squared log-residuals and prior variance based on genes whose + # dispersions are above 100 * min_disp. This is to reproduce DESeq2's behaviour. + above_min_disp = genewise_dispersions[non_zero] >= (100 * self.min_disp) + + _squared_logres = mean_absolute_deviation(disp_residuals[above_min_disp]) ** 2 + + prior_disp_var = np.maximum( + _squared_logres - polygamma(1, (n_total_obs - n_params) / 2), + 0.25, + ) + + return { + "prior_disp_var": prior_disp_var, + "_squared_logres": _squared_logres, + "trend_coeffs": coeffs, + "fitted_dispersions": fitted_dispersions, + "disp_function_type": disp_function_type, + "mean_disp": mean_disp, + } + + +class LocUpdateFittedDispersions: + """Mixin to update the fitted dispersions after replacing outliers. + + To use in refit mode only + """ + + local_adata: ad.AnnData + refit_adata: ad.AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def loc_update_fitted_dispersions( + self, + data_from_opener, + shared_state: dict, + ) -> None: + """ + Update the fitted dispersions after replacing outliers. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + A dictionary with a "fitted_dispersions" key, containing the dispersions + fitted before replacing the outliers. 
+ """ + # Start by updating gene-wise dispersions + self.refit_adata.varm["genewise_dispersions"] = shared_state[ + "genewise_dispersions" + ] + + # Update the fitted dispersions + non_zero = self.refit_adata.varm["non_zero"] + self.refit_adata.uns["disp_function_type"] = self.local_adata.uns[ + "disp_function_type" + ] + + fitted_dispersions = np.full_like( + self.refit_adata.varm["genewise_dispersions"], np.NaN + ) + + fitted_dispersions[non_zero] = disp_function( + self.refit_adata.varm["_normed_means"][non_zero], + disp_function_type=self.refit_adata.uns["disp_function_type"], + coeffs=self.refit_adata.uns["trend_coeffs"], + mean_disp=self.refit_adata.uns["mean_disp"] + if self.refit_adata.uns["disp_function_type"] == "parametric" + else None, + ) + + self.refit_adata.varm["fitted_dispersions"] = fitted_dispersions diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/utils.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/utils.py new file mode 100644 index 0000000..3067617 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/utils.py @@ -0,0 +1,25 @@ +from typing import Union + +import numpy as np +import pandas as pd +from pydeseq2.utils import dispersion_trend + + +def disp_function( + x, + disp_function_type, + coeffs: Union["pd.Series[float]", np.ndarray] | None = None, + mean_disp: float | None = None, +) -> float | np.ndarray: + """Return the dispersion trend function at x.""" + if disp_function_type == "parametric": + assert coeffs is not None, "coeffs must be provided for parametric dispersion." + return dispersion_trend(x, coeffs=coeffs) + elif disp_function_type == "mean": + assert mean_disp is not None, "mean_disp must be provided for mean dispersion." 
+ return np.full_like(x, mean_disp) + else: + raise ValueError( + "disp_function_type must be 'parametric' or 'mean'," + f" got {disp_function_type}" + ) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/__init__.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/__init__.py new file mode 100644 index 0000000..fcf77d1 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/__init__.py @@ -0,0 +1,5 @@ +"""Module containing the mixin class to compute genewise dispersions.""" + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.compute_genewise_dispersions import ( # noqa: E501 + ComputeGenewiseDispersions, +) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/__init__.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/__init__.py new file mode 100644 index 0000000..c9f5236 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/__init__.py @@ -0,0 +1,5 @@ +"""Module to implement the computation of MoM dispersions.""" + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.compute_MoM_dispersions.compute_MoM_dispersions import ( # noqa: E501 + ComputeMoMDispersions, +) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_MoM_dispersions.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_MoM_dispersions.py new file mode 100644 index 0000000..a782465 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_MoM_dispersions.py @@ -0,0 +1,124 @@ +"""Main module to compute method of moments (MoM) dispersions.""" + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.compute_MoM_dispersions.compute_rough_dispersions import ( # noqa: E501 + ComputeRoughDispersions, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.compute_MoM_dispersions.substeps import ( # noqa: E501 + AggMomentsDispersion, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.compute_MoM_dispersions.substeps import ( # noqa: E501 + LocInvSizeMean, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class ComputeMoMDispersions( + ComputeRoughDispersions, + LocInvSizeMean, + AggMomentsDispersion, +): + """Mixin class to implement the computation of MoM dispersions. + + Relies on the ComputeRoughDispersions class, in addition to substeps. + + Methods + ------- + compute_MoM_dispersions + The method to compute the MoM dispersions, that must be used in the main + pipeline. + + """ + + def compute_MoM_dispersions( + self, + train_data_nodes, + aggregation_node, + local_states, + gram_features_shared_states, + round_idx, + clean_models, + refit_mode: bool = False, + ): + """Compute method of moments dispersions. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. 
+ + gram_features_shared_states: list + The list of shared states outputed by the compute_size_factors step. + They contain a "local_gram_matrix" and a "local_features" fields. + + round_idx: int + The current round. + + clean_models: bool + Whether to clean the models after the computation. + + refit_mode: bool + Whether to run the pipeline in refit mode, after cooks outliers were + replaced. If True, the pipeline will be run on `refit_adata`s instead of + `local_adata`s (default: False). + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + mom_dispersions_shared_state: dict + Shared states containing MoM dispersions. + + round_idx: int + The updated round number. + + """ + ###### Fit rough dispersions ###### + + local_states, shared_states, round_idx = self.compute_rough_dispersions( + train_data_nodes, + aggregation_node, + local_states, + gram_features_shared_states=gram_features_shared_states, + round_idx=round_idx, + clean_models=clean_models, + refit_mode=refit_mode, + ) + + ###### Compute moments dispersions ###### + + # ---- Compute local means for moments dispersions---- # + + local_states, shared_states, round_idx = local_step( + local_method=self.local_inverse_size_mean, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=shared_states, + aggregation_id=aggregation_node.organization_id, + description="Compute local inverse size factor means.", + clean_models=clean_models, + method_params={"refit_mode": refit_mode}, + ) + + # ---- Compute moments dispersions and merge to get MoM dispersions ---- # + + mom_dispersions_shared_state, round_idx = aggregation_step( + aggregation_method=self.aggregate_moments_dispersions, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Compute global MoM dispersions", + clean_models=clean_models, + ) + + return local_states, mom_dispersions_shared_state, round_idx diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_rough_dispersions.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_rough_dispersions.py new file mode 100644 index 0000000..87c8bd4 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_rough_dispersions.py @@ -0,0 +1,124 @@ +"""Module to compute rough dispersions.""" + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.compute_MoM_dispersions.substeps import ( # noqa: E501 + AggCreateRoughDispersionsSystem, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.compute_MoM_dispersions.substeps import ( # noqa: E501 + AggRoughDispersion, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.compute_MoM_dispersions.substeps import ( # noqa: E501 + LocRoughDispersion, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class ComputeRoughDispersions( + AggRoughDispersion, + LocRoughDispersion, + AggCreateRoughDispersionsSystem, +): + """Mixin class to implement the computation of rough dispersions. 
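+
+    Rough dispersions come from an ordinary least squares fit of the normalized
+    counts on the design matrix: the aggregation node pools the local Gram
+    matrices and feature vectors and solves the normal equations once for all
+    genes, and each center then computes a residual-based estimate on its own
+    data. A minimal single-machine sketch of the pooled normal equations (plain
+    NumPy, random toy data, for illustration only):
+
+    >>> import numpy as np
+    >>> rng = np.random.default_rng(0)
+    >>> designs = [rng.normal(size=(5, 2)) for _ in range(3)]  # one per center
+    >>> counts = [rng.random(size=(5, 4)) for _ in range(3)]  # normed counts
+    >>> gram = sum(d.T @ d for d in designs)  # pooled Gram matrix
+    >>> features = sum(d.T @ y for d, y in zip(designs, counts))
+    >>> beta = np.linalg.solve(gram, features)  # (n_params, n_genes)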
+ + Methods + ------- + compute_rough_dispersions + The method to compute the rough dispersions, that must be used in the main + pipeline. + + """ + + def compute_rough_dispersions( + self, + train_data_nodes, + aggregation_node, + local_states, + gram_features_shared_states, + round_idx, + clean_models, + refit_mode: bool = False, + ): + """Compute rough dispersions. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + gram_features_shared_states: list + The list of shared states outputed by the compute_size_factors step. + They contain a "local_gram_matrix" and a "local_features" fields. + + round_idx: int + The current round. + + clean_models: bool + Whether to clean the models after the computation. + + refit_mode: bool + Whether to run the pipeline in refit mode, after cooks outliers were + replaced. If True, the pipeline will be run on `refit_adata`s instead of + `local_adata`s (default: False). + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + rough_dispersion_shared_state: dict + Shared states containing rough dispersions. + + round_idx: int + The updated round number. + + """ + # TODO: in refit mode, we need to gather the gram matrix and the features some + # way + + # ---- Solve global linear system ---- # + + rough_dispersion_system_shared_state, round_idx = aggregation_step( + aggregation_method=self.create_rough_dispersions_system, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=gram_features_shared_states, + round_idx=round_idx, + description="Solving system for rough dispersions", + clean_models=clean_models, + method_params={"refit_mode": refit_mode}, + ) + + # ---- Compute local rough dispersions---- # + + local_states, shared_states, round_idx = local_step( + local_method=self.local_rough_dispersions, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=rough_dispersion_system_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Computing local rough dispersions", + clean_models=clean_models, + method_params={"refit_mode": refit_mode}, + ) + + # ---- Compute global rough dispersions---- # + + rough_dispersion_shared_state, round_idx = aggregation_step( + aggregation_method=self.aggregate_rough_dispersions, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Compute global rough dispersions", + clean_models=clean_models, + ) + + return local_states, rough_dispersion_shared_state, round_idx diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py new file mode 100644 index 0000000..dace422 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py @@ -0,0 +1,301 @@ +"""Module to implement the substeps for the rough dispersions step. + +This module contains all these substeps as mixin classes. 
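+
+It provides the ``AggRoughDispersion``, ``AggCreateRoughDispersionsSystem``,
+``LocRoughDispersion``, ``LocInvSizeMean`` and ``AggMomentsDispersion`` mixins.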
+""" + +import numpy as np +from anndata import AnnData +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.layers.build_layers import set_y_hat +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class AggRoughDispersion: + """Mixin to aggregate local rough dispersions.""" + + @remote + @log_remote + def aggregate_rough_dispersions(self, shared_states): + """Aggregate local rough dispersions. + + Parameters + ---------- + shared_states : list + List of results (rough_dispersions, n_obs, n_params) from training nodes. + + Returns + ------- + dict + Global rough dispersions. + """ + rough_dispersions = sum( + [state["local_rough_dispersions"] for state in shared_states] + ) + + tot_obs = sum([state["local_n_obs"] for state in shared_states]) + n_params = shared_states[0]["local_n_params"] + + if tot_obs <= n_params: + raise ValueError( + "The number of samples is smaller or equal to the number of design " + "variables, i.e., there are no replicates to estimate the " + "dispersions. Please use a design with fewer variables." + ) + + return { + "rough_dispersions": np.maximum(rough_dispersions / (tot_obs - n_params), 0) + } + + +class AggCreateRoughDispersionsSystem: + """Mixin to solve the linear system for rough dispersions.""" + + @remote + @log_remote + def create_rough_dispersions_system(self, shared_states, refit_mode: bool = False): + """Solve the linear system in for rough dispersions. + + Parameters + ---------- + shared_states : list + List of results (local_gram_matrix, local_features) from training nodes. + + refit_mode : bool + Whether to run the pipeline in refit mode, after cooks outliers were + replaced. If True, there is no need to compute the Gram matrix which was + already computed in the compute_size_factors step (default: False). + + Returns + ------- + dict + The global feature vector and the global hat matrix if refit_mode is + ``False``. + """ + shared_state = { + "global_feature_vector": sum( + [state["local_features"] for state in shared_states] + ) + } + if not refit_mode: + shared_state["global_gram_matrix"] = sum( + [state["local_gram_matrix"] for state in shared_states] + ) + + return shared_state + + +class LocRoughDispersion: + """Mixin to compute local rough dispersions.""" + + local_adata: AnnData + refit_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def local_rough_dispersions( + self, data_from_opener, shared_state, refit_mode: bool = False + ) -> dict: + """Compute local rough dispersions, and save the global gram matrix. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Shared state containing + - the gram matrix, if refit_mode is ``False``, + - the global feature vector. + + refit_mode : bool + Whether to run the pipeline in refit mode, after cooks outliers were + replaced. If True, the pipeline will be run on `refit_adata`s instead of + `local_adata`s (default: False). + + Returns + ------- + dict + Dictionary containing local rough dispersions, number of samples and + number of parameters (i.e. number of columns in the design matrix). 
+ """ + if not refit_mode: + global_gram_matrix = shared_state["global_gram_matrix"] + self.local_adata.uns["_global_gram_matrix"] = global_gram_matrix + else: + global_gram_matrix = self.local_adata.uns["_global_gram_matrix"] + + beta_rough_dispersions = np.linalg.solve( + global_gram_matrix, shared_state["global_feature_vector"] + ) + + if refit_mode: + adata = self.refit_adata + else: + adata = self.local_adata + adata.varm["_beta_rough_dispersions"] = beta_rough_dispersions.T + # Save the rough dispersions beta so that we can reconstruct y_hat + set_y_hat(adata) + + # Save global beta in the local data because so it can be used later in + # fit_lin_mu. Do it before clipping. + + y_hat = np.maximum(adata.layers["_y_hat"], 1) + unnormed_alpha_rde = ( + ((adata.layers["normed_counts"] - y_hat) ** 2 - y_hat) / (y_hat**2) + ).sum(0) + return { + "local_rough_dispersions": unnormed_alpha_rde, + "local_n_obs": adata.n_obs, + "local_n_params": adata.uns["n_params"], + } + + +class LocInvSizeMean: + """Mixin to compute local means of inverse size factors.""" + + local_adata: AnnData + refit_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def local_inverse_size_mean( + self, data_from_opener, shared_state=None, refit_mode: bool = False + ) -> dict: + """Compute local means of inverse size factors, counts, and squared counts. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Shared state containing rough dispersions from aggregator. + + refit_mode : bool + Whether to run the pipeline in refit mode, after cooks outliers were + replaced. If True, the pipeline will be run on `refit_adata`s instead of + `local_adata`s (default: False). + + Returns + ------- + dict + dictionary containing all quantities required to compute MoM dispersions: + local inverse size factor means, counts means, squared counts means, + rough dispersions and number of samples. + """ + if refit_mode: + adata = self.refit_adata + else: + adata = self.local_adata + + adata.varm["_rough_dispersions"] = shared_state["rough_dispersions"] + + return { + "local_inverse_size_mean": (1 / adata.obsm["size_factors"]).mean(), + "local_counts_mean": adata.layers["normed_counts"].mean(0), + "local_squared_squared_mean": (adata.layers["normed_counts"] ** 2).mean(0), + "local_n_obs": adata.n_obs, + # Pass rough dispersions to the aggregation node, to compute MoM dispersions + "rough_dispersions": shared_state["rough_dispersions"], + } + + +class AggMomentsDispersion: + """Mixin to compute MoM dispersions.""" + + local_adata: AnnData + max_disp: float + min_disp: float + + @remote + @log_remote + def aggregate_moments_dispersions(self, shared_states): + """Compute global moments dispersions. + + Parameters + ---------- + shared_states : list + List of results (local_inverse_size_mean, local_counts_mean, + local_squared_squared_mean, local_n_obs, rough_dispersions) + from training nodes. + + Returns + ------- + dict + Global moments dispersions, the mask of all zero genes, the total + number of samples (used to set max_disp and lr), and + the total normed counts mean (used in the independent filtering + step). 
+ """ + tot_n_obs = sum([state["local_n_obs"] for state in shared_states]) + + # Compute the mean of inverse size factors + tot_inv_size_mean = ( + sum( + [ + state["local_n_obs"] * state["local_inverse_size_mean"] + for state in shared_states + ] + ) + / tot_n_obs + ) + + # Compute the mean and variance of normalized counts + + tot_counts_mean = ( + sum( + [ + state["local_n_obs"] * state["local_counts_mean"] + for state in shared_states + ] + ) + / tot_n_obs + ) + non_zero = tot_counts_mean != 0 + + tot_squared_mean = ( + sum( + [ + state["local_n_obs"] * state["local_squared_squared_mean"] + for state in shared_states + ] + ) + / tot_n_obs + ) + + counts_variance = ( + tot_n_obs / (tot_n_obs - 1) * (tot_squared_mean - tot_counts_mean**2) + ) + + moments_dispersions = np.zeros( + counts_variance.shape, dtype=counts_variance.dtype + ) + moments_dispersions[non_zero] = ( + counts_variance[non_zero] - tot_inv_size_mean * tot_counts_mean[non_zero] + ) / tot_counts_mean[non_zero] ** 2 + + # Get rough dispersions from the first center + rough_dispersions = shared_states[0]["rough_dispersions"] + + # Compute the maximum dispersion + max_disp = np.maximum(self.max_disp, tot_n_obs) + + # Return moment estimate + alpha_hat = np.minimum(rough_dispersions, moments_dispersions) + MoM_dispersions = np.clip(alpha_hat, self.min_disp, max_disp) + + # Set MoM dispersions of all zero genes to NaN + + MoM_dispersions[~non_zero] = np.nan + return { + "MoM_dispersions": MoM_dispersions, + "non_zero": non_zero, + "tot_num_samples": tot_n_obs, + "tot_counts_mean": tot_counts_mean, + } diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_genewise_dispersions.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_genewise_dispersions.py new file mode 100644 index 0000000..190d4d5 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_genewise_dispersions.py @@ -0,0 +1,192 @@ +"""Main module to compute genewise dispersions.""" +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.compute_MoM_dispersions import ( # noqa: E501 + ComputeMoMDispersions, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.get_num_replicates import ( # noqa: E501 + GetNumReplicates, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.substeps import ( # noqa: E501 + LocLinMu, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.substeps import ( # noqa: E501 + LocSetMuHat, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_lfc import ComputeLFC +from fedpydeseq2.core.fed_algorithms import ComputeDispersionsGridSearch +from fedpydeseq2.core.utils import local_step + + +class ComputeGenewiseDispersions( + ComputeDispersionsGridSearch, + ComputeMoMDispersions, + LocLinMu, + GetNumReplicates, + ComputeLFC, + LocSetMuHat, +): + """ + Mixin class to implement the computation of both genewise and MAP dispersions. + + The switch between genewise and MAP dispersions is done by setting the `fit_mode` + argument in the `fit_dispersions` to either "MLE" or "MAP". + + Methods + ------- + fit_gene_wise_dispersions + A method to fit gene-wise dispersions using a grid search. + Performs four steps: + 1. Compute the first dispersions estimates using a + method of moments (MoM) approach. + 2. 
Compute the number of replicates for each combination of factors. + This step is necessary to compute the mean estimate in one case, and + in downstream steps (cooks distance, etc). + 3. Compute an estimate of the mean from these dispersions. + 4. Fit the dispersions using a grid search. + + + """ + + def fit_genewise_dispersions( + self, + train_data_nodes, + aggregation_node, + local_states, + gram_features_shared_states, + round_idx, + clean_models, + refit_mode: bool = False, + ): + """Fit the gene-wise dispersions. + + Performs four steps: + 1. Compute the first dispersions estimates using a + method of moments (MoM) approach. + 2. Compute the number of replicates for each combination of factors. + This step is necessary to compute the mean estimate in one case, and + in downstream steps (cooks distance, etc). + 3. Compute an estimate of the mean from these dispersions. + 4. Fit the dispersions using a grid search. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + gram_features_shared_states: list + The list of shared states outputed by the compute_size_factors step. + They contain a "local_gram_matrix" and a "local_features" fields. + + round_idx: int + The current round. + + clean_models: bool + Whether to clean the models after the computation. + + refit_mode: bool + Whether to run the pipeline in refit mode, after cooks outliers were + replaced. If True, the pipeline will be run on `refit_adata`s instead of + `local_adata`s. (default: False). + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + shared_state: dict or list[dict] + A dictionary containing: + - "genewise_dispersions": The MLE dispersions, to be stored locally at + - "lower_log_bounds": log lower bounds for the grid search (only used in + internal loop), + - "upper_log_bounds": log upper bounds for the grid search (only used in + internal loop). + + round_idx: int + The updated round index. + """ + # ---- Compute MoM dispersions ---- # + ( + local_states, + mom_dispersions_shared_state, + round_idx, + ) = self.compute_MoM_dispersions( + train_data_nodes, + aggregation_node, + local_states, + gram_features_shared_states, + round_idx, + clean_models, + refit_mode=refit_mode, + ) + + # ---- Compute the initial mu estimates ---- # + + # 1 - Compute the linear mu estimates. + + local_states, linear_shared_states, round_idx = local_step( + local_method=self.fit_lin_mu, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=mom_dispersions_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Compute local linear mu estimates.", + round_idx=round_idx, + clean_models=clean_models, + method_params={"refit_mode": refit_mode}, + ) + + # 2 - Compute IRLS estimates. + local_states, round_idx = self.compute_lfc( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + lfc_mode="mu_init", + refit_mode=refit_mode, + ) + + # 3 - Compare the number of replicates to the number of design matrix columns + # and decide whether to use the IRLS estimates or the linear estimates. 
+ + # Compute the number of replicates + local_states, round_idx = self.get_num_replicates( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.set_mu_hat, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Pick between linear and irls mu_hat.", + round_idx=round_idx, + clean_models=clean_models, + method_params={"refit_mode": refit_mode}, + ) + + # ---- Fit dispersions ---- # + local_states, shared_state, round_idx = self.fit_dispersions( + train_data_nodes, + aggregation_node, + local_states, + shared_state=None, + round_idx=round_idx, + clean_models=clean_models, + fit_mode="MLE", + refit_mode=refit_mode, + ) + + return local_states, shared_state, round_idx diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/__init__.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/__init__.py new file mode 100644 index 0000000..ca81a82 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/__init__.py @@ -0,0 +1,5 @@ +"""Module containing the mixin class to compute the number of replicates.""" + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.get_num_replicates.get_num_replicates import ( # noqa: E501 + GetNumReplicates, +) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/get_num_replicates.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/get_num_replicates.py new file mode 100644 index 0000000..abf3ee8 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/get_num_replicates.py @@ -0,0 +1,83 @@ +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.get_num_replicates.substeps import ( # noqa: E501 + AggGetCountsLvlForCells, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.get_num_replicates.substeps import ( # noqa: E501 + LocFinalizeCellCounts, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.get_num_replicates.substeps import ( # noqa: E501 + LocGetDesignMatrixLevels, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class GetNumReplicates( + LocGetDesignMatrixLevels, AggGetCountsLvlForCells, LocFinalizeCellCounts +): + """Mixin class to get the number of replicates for each combination of factors.""" + + def get_num_replicates( + self, train_data_nodes, aggregation_node, local_states, round_idx, clean_models + ): + """ + Compute the number of replicates for each combination of factors. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: int + Index of the current round. + + clean_models: bool + Whether to clean the models after the computation. + + Returns + ------- + local_states: dict + Local states, to store the number of replicates and cell level codes. 
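+            After this step, each ``local_adata`` holds
+            ``uns["num_replicates"]`` (the number of samples per design-matrix
+            level, summed over all centers) and ``obs["cells"]`` (the integer
+            level code of each local sample).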
+ + round_idx: int + The updated round index. + """ + local_states, shared_states, round_idx = local_step( + local_method=self.loc_get_design_matrix_levels, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Get local matrix design level", + round_idx=round_idx, + clean_models=clean_models, + ) + counts_lvl_share_state, round_idx = aggregation_step( + aggregation_method=self.agg_get_counts_lvl_for_cells, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Compute counts level", + round_idx=round_idx, + clean_models=clean_models, + ) + + local_states, _, round_idx = local_step( + local_method=self.loc_finalize_cell_counts, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=counts_lvl_share_state, + aggregation_id=aggregation_node.organization_id, + description="Finalize cell counts", + round_idx=round_idx, + clean_models=clean_models, + ) + + return local_states, round_idx diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/substeps.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/substeps.py new file mode 100644 index 0000000..1ab0675 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/substeps.py @@ -0,0 +1,109 @@ +import numpy as np +import pandas as pd +from anndata import AnnData +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocGetDesignMatrixLevels: + """Mixin to get the unique values of the local design matrix.""" + + local_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def loc_get_design_matrix_levels(self, data_from_opener, shared_state=dict) -> dict: + """ + Get the values of the local design matrix. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + shared_state : dict + Not used. + + Returns + ------- + dict + Dictionary with the following key: + - unique_counts: unique values and counts of the local design matrix + + """ + unique_counts = self.local_adata.obsm["design_matrix"].value_counts() + + return {"unique_counts": unique_counts} + + +class AggGetCountsLvlForCells: + """Mixin that aggregate the counts of the design matrix values.""" + + @remote + @log_remote + def agg_get_counts_lvl_for_cells(self, shared_states: list[dict]) -> dict: + """ + Aggregate the counts of the design matrix values. + + Parameters + ---------- + shared_states : list(dict) + List of shared states with the following key: + - unique_counts: unique values and counts of the local design matrix + + Returns + ------- + dict + Dictionary with keys labeling the different values taken by the + overall design matrix. Each values of the dictionary contains the + sum of the counts of the corresponding design matrix value and the level. 
+ """ + concat_unique_cont = pd.concat( + [shared_state["unique_counts"] for shared_state in shared_states], axis=1 + ) + counts_by_lvl = concat_unique_cont.fillna(0).sum(axis=1).astype(int) + + return {"counts_by_lvl": counts_by_lvl} + + +class LocFinalizeCellCounts: + """Mixin that finalize the cell counts.""" + + local_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def loc_finalize_cell_counts(self, data_from_opener, shared_state=dict) -> None: + """ + Finalize the cell counts. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Dictionary with keys labeling the different values taken by the + overall design matrix. Each values of the dictionary contains the + sum of the counts of the corresponding design matrix value and the level. + + """ + counts_by_lvl = shared_state["counts_by_lvl"] + + # In order to keep the same objects 'num_replicates' and 'cells' used in + # PyDESeq2, we provide names (0, 1, 2...) to the possible values of the + # design matrix, called "lvl". + # The index of 'num_replicates' is the lvl names (0,1,2...) and its values + # the counts of these lvl + # 'cells' index is the index of the cells in the adata and its values the lvl + # name (0,1,2..) of the cell. + self.local_adata.uns["num_replicates"] = pd.Series(counts_by_lvl.values) + self.local_adata.obs["cells"] = [ + np.argwhere(counts_by_lvl.index == tuple(design))[0, 0] + for design in self.local_adata.obsm["design_matrix"].values + ] diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/substeps.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/substeps.py new file mode 100644 index 0000000..81ec3d9 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/substeps.py @@ -0,0 +1,110 @@ +from anndata import AnnData +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.layers.build_layers import set_fit_lin_mu_hat +from fedpydeseq2.core.utils.layers.build_layers import set_mu_hat_layer +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocLinMu: + """Mixin to fit linear mu estimates locally.""" + + local_adata: AnnData + refit_adata: AnnData + min_mu: float + max_disp: float + + @remote_data + @log_remote_data + @reconstruct_adatas + def fit_lin_mu( + self, data_from_opener, shared_state, min_mu=0.5, refit_mode: bool = False + ): + """Fit linear mu estimates and store them locally. + + Parameters + ---------- + data_from_opener : ad.AnnData + Not used. + + shared_state : dict + Contains values to be saved in local adata: + - "MoM_dispersions": MoM dispersions, + - "nom_zero": Mask of all zero genes, + - "tot_num_samples": Total number of samples. + + min_mu : float + Lower threshold for fitted means, for numerical stability. + (default: ``0.5``). + + refit_mode : bool + Whether to run the pipeline in refit mode. If True, the pipeline will be run + on `refit_adata`s instead of `local_adata`s. (default: ``False``). + + """ + if refit_mode: + adata = self.refit_adata + else: + adata = self.local_adata + + # save MoM dispersions computed in the previous step + adata.varm["_MoM_dispersions"] = shared_state["MoM_dispersions"] + + # save mask of all zero genes. 
+ # TODO: check that we should also do this in refit mode + adata.varm["non_zero"] = shared_state["non_zero"] + + if not refit_mode: # In refit mode, those are unchanged + # save the total number of samples + self.local_adata.uns["tot_num_samples"] = shared_state["tot_num_samples"] + + # use it to set max_disp + self.local_adata.uns["max_disp"] = max( + self.max_disp, self.local_adata.uns["tot_num_samples"] + ) + + # save the base_mean for independent filtering + adata.varm["_normed_means"] = shared_state["tot_counts_mean"] + + # compute mu_hat + set_fit_lin_mu_hat(adata, min_mu=min_mu) + + +class LocSetMuHat: + """Mixin to set mu estimates locally.""" + + local_adata: AnnData + refit_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def set_mu_hat( + self, + data_from_opener, + shared_state, + refit_mode: bool = False, + ) -> None: + """Pick between linear and IRLS mu estimates. + + Parameters + ---------- + data_from_opener : ad.AnnData + Not used. + + shared_state : dict + Not used. + + refit_mode : bool + Whether to run on `refit_adata`s instead of `local_adata`s. + (default: ``False``). + """ + if refit_mode: + adata = self.refit_adata + else: + adata = self.local_adata + # TODO make sure that the adata has the num_replicates and the n_params + set_mu_hat_layer(adata) + del adata.layers["_fit_lin_mu_hat"] + del adata.layers["_irls_mu_hat"] diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/__init__.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/__init__.py new file mode 100644 index 0000000..0175290 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/__init__.py @@ -0,0 +1,5 @@ +"""Module which contains the Mixin in charge of fitting log fold changes.""" + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_lfc.compute_lfc import ( # noqa: E501 + ComputeLFC, +) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc.py new file mode 100644 index 0000000..2afcf9f --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc.py @@ -0,0 +1,172 @@ +"""Module containing the ComputeLFC method.""" +from typing import Literal + +from substrafl.nodes import AggregationNode + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_lfc.substeps import ( + AggCreateBetaInit, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_lfc.substeps import ( + LocGetGramMatrixAndLogFeatures, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_lfc.substeps import ( + LocSaveLFC, +) +from fedpydeseq2.core.fed_algorithms import FedIRLS +from fedpydeseq2.core.fed_algorithms import FedProxQuasiNewton +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class ComputeLFC( + LocGetGramMatrixAndLogFeatures, + AggCreateBetaInit, + LocSaveLFC, + FedProxQuasiNewton, + FedIRLS, +): + r"""Mixin class to implement the LFC computation algorithm. + + The goal of this class is to implement the IRLS algorithm specifically applied + to the negative binomial distribution, with fixed dispersion parameter, and + in the case where it fails, to catch it with the FedProxQuasiNewton algorithm. + + This class also initializes the beta parameters and computes the final hat matrix. 
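    For reference, the per-gene update that one IRLS iteration performs for a
    negative binomial GLM with log link and fixed dispersion is a weighted
    least-squares step. The sketch below states the textbook version, assuming
    the DESeq2 weighting w = mu / (1 + alpha * mu); the federated summands
    actually used by this class live in fed_algorithms.fed_irls and may differ
    in detail.

    import numpy as np

    def irls_update_sketch(y, X, size_factors, beta, alpha, min_mu=0.5):
        # One IRLS step for a single gene: weighted least squares on the
        # working response z, with NB variance-based weights.
        mu = np.maximum(size_factors * np.exp(X @ beta), min_mu)
        W = mu / (1.0 + alpha * mu)                    # IRLS weights
        z = np.log(mu / size_factors) + (y - mu) / mu  # working response
        XtWX = X.T @ (W[:, None] * X)
        XtWz = X.T @ (W * z)
        return np.linalg.solve(XtWX, XtWz)             # updated beta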
+ + Methods + ------- + compute_lfc + The main method to compute the log fold changes by + running the IRLS algorithm and catching it with the + FedProxQuasiNewton algorithm. + + + """ + + def compute_lfc( + self, + train_data_nodes: list, + aggregation_node: AggregationNode, + local_states: dict, + round_idx: int, + clean_models: bool = True, + lfc_mode: Literal["lfc", "mu_init"] = "lfc", + refit_mode: bool = False, + ): + """Compute the log fold changes. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: int + The current round. + + clean_models: bool + If True, the models are cleaned. + + lfc_mode: Literal["lfc", "mu_init"] + The mode of the IRLS algorithm ("lfc" or "mu_init"). + + refit_mode: bool + Whether to run the pipeline in refit mode, after cooks outliers were + replaced. If True, the pipeline will be run on `refit_adata`s instead of + `local_adata`s. (default: False). + + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: int + The updated round index. + + """ + #### ---- Initialization ---- #### + + # ---- Compute initial local beta estimates ---- # + + local_states, local_beta_init_shared_states, round_idx = local_step( + local_method=self.get_gram_matrix_and_log_features, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Create local initialization beta.", + clean_models=clean_models, + method_params={ + "lfc_mode": lfc_mode, + "refit_mode": refit_mode, + }, + ) + + # ---- Compute initial global beta estimates ---- # + + global_irls_summands_nlls_shared_state, round_idx = aggregation_step( + aggregation_method=self.create_beta_init, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=local_beta_init_shared_states, + description="Create initialization beta paramater.", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### ---- Run IRLS ---- ##### + ( + local_states, + irls_result_shared_state, + round_idx, + ) = self.run_fed_irls( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + input_shared_state=global_irls_summands_nlls_shared_state, + round_idx=round_idx, + clean_models=clean_models, + refit_mode=refit_mode, + ) + + #### ---- Catch with FedProxQuasiNewton ----#### + + local_states, PQN_shared_state, round_idx = self.run_fed_PQN( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + PQN_shared_state=irls_result_shared_state, + first_iteration_mode="irls_catch", + round_idx=round_idx, + clean_models=clean_models, + refit_mode=refit_mode, + ) + + # ---- Compute final hat matrix summands ---- # + + ( + local_states, + _, + round_idx, + ) = local_step( + local_method=self.save_lfc_to_local, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=PQN_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Compute local hat matrix summands and last nll.", + clean_models=clean_models, + method_params={"refit_mode": refit_mode}, + ) + + return local_states, round_idx diff --git 
a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py new file mode 100644 index 0000000..2b391cf --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py @@ -0,0 +1,400 @@ +"""Module to implement the substeps for the fitting of log fold changes. + +This module contains all these substeps as mixin classes. +""" + +from typing import Any +from typing import Literal + +import numpy as np +import pandas as pd +from anndata import AnnData +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.layers.utils import set_mu_layer +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocGetGramMatrixAndLogFeatures: + """Mixin accessing the quantities to compute the initial beta of ComputeLFC. + + Attributes + ---------- + local_adata : AnnData + The local AnnData object. + + Methods + ------- + get_gram_matrix_and_log_features + A remote_data method. Creates the local quantities necessary + to compute the initial beta. + If the gram matrix is full rank, it shares the features vector + and the gram matrix. If the gram matrix is not full rank, it shares + the normed log means and the number of observations. + + """ + + local_adata: AnnData + refit_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_gram_matrix_and_log_features( + self, + data_from_opener: AnnData, + shared_state: dict[str, Any], + lfc_mode: Literal["lfc", "mu_init"], + refit_mode: bool = False, + ): + """Create the local quantities necessary to compute the initial beta. + + To do so, we assume that the local_adata.uns contains the following fields: + - n_params: int + The number of parameters. + - _global_gram_matrix: ndarray + The global gram matrix. + + From the IRLS mode, we will set the following fields: + - _irls_mu_param_name: str + The name of the mu parameter, to save at the end of the IRLS run + This is None if we do not want to save the mu parameter. + - _irls_beta_param_name: str + The name of the beta parameter, to save as a varm at the end of the + fed irls run + This is None if we do not want to save the beta parameter. + - _irls_disp_param_name: str + The name of the dispersion parameter. + - _lfc_mode: str + The mode of the IRLS algorithm. This is used to set the previous fields. + + Parameters + ---------- + data_from_opener : AnnData + Not used. + + shared_state : dict + Not used, all the necessary info is stored in the local adata. + + lfc_mode : Literal["lfc", "mu_init"] + The mode of the IRLS algorithm ("lfc", or "mu_init"). + + refit_mode : bool + Whether to run the pipeline on `refit_adata` instead of `local_adata`. + + Returns + ------- + dict + The state to share to the server. + It always contains the following fields: + - gram_full_rank: bool + Whether the gram matrix is full rank. + - n_non_zero_genes: int + The number of non zero genes. + - n_params: int + The number of parameters. + - If the gram matrix is full rank, the state contains: + - local_log_features: ndarray + The local log features. + - global_gram_matrix: ndarray + The global gram matrix. + - If the gram matrix is not full rank, the state contains: + - normed_log_means: ndarray + The normed log means. + - n_obs: int + The number of observations. 
+ + """ + global_gram_matrix = self.local_adata.uns["_global_gram_matrix"] + + if refit_mode: + adata = self.refit_adata + else: + adata = self.local_adata + + # Elements to pass on to the next steps of the method + if lfc_mode == "lfc": + adata.uns["_irls_mu_param_name"] = "_mu_LFC" + adata.uns["_irls_beta_param_name"] = "LFC" + adata.uns["_irls_disp_param_name"] = "dispersions" + adata.uns["_lfc_mode"] = "lfc" + elif lfc_mode == "mu_init": + adata.uns["_irls_mu_param_name"] = "_irls_mu_hat" + adata.uns["_irls_beta_param_name"] = "_mu_hat_LFC" + adata.uns["_irls_disp_param_name"] = "_MoM_dispersions" + adata.uns["_lfc_mode"] = "mu_init" + + else: + raise NotImplementedError( + f"Only 'lfc' and 'mu_init' irls modes are supported, got {lfc_mode}." + ) + + # Get non zero genes + non_zero_genes_names = adata.var_names[adata.varm["non_zero"]] + + # See if gram matrix is full rank + gram_full_rank = ( + np.linalg.matrix_rank(global_gram_matrix) == adata.uns["n_params"] + ) + # If the gram matrix is full rank, share the features vector and the gram + # matrix + + shared_state = { + "gram_full_rank": gram_full_rank, + "n_non_zero_genes": len(non_zero_genes_names), + } + + if gram_full_rank: + # Make log features + design = adata.obsm["design_matrix"].values + log_counts = np.log( + adata[:, non_zero_genes_names].layers["normed_counts"] + 0.1 + ) + log_features = (design.T @ log_counts).T + shared_state.update( + { + "local_log_features": log_features, + "global_gram_matrix": global_gram_matrix, + } + ) + else: + # TODO: check that this is correctly recomputed in refit mode + if "normed_log_means" not in adata.varm: + with np.errstate(divide="ignore"): # ignore division by zero warnings + log_counts = np.log(adata.layers["normed_counts"]) + adata.varm["normed_log_means"] = log_counts.mean(0) + normed_log_means = adata.varm["normed_log_means"] + n_obs = adata.n_obs + shared_state.update({"normed_log_means": normed_log_means, "n_obs": n_obs}) + return shared_state + + +class AggCreateBetaInit: + """Mixin to create the beta init. + + Methods + ------- + create_beta_init + A remote method. Creates the beta init (initialization value for the + ComputeLFC algorithm) and returns the initialization state for the + IRLS algorithm containing this initialization value and + other necessary quantities. + """ + + @remote + @log_remote + def create_beta_init(self, shared_states: list[dict]) -> dict[str, Any]: + """Create the beta init. + + It does so either by solving the least squares regression system if + the gram matrix is full rank, or by aggregating the log means if the + gram matrix is not full rank. + + Parameters + ---------- + shared_states: list[dict] + A list of dictionaries containing the following + keys: + - gram_full_rank: bool + Whether the gram matrix is full rank. + - n_non_zero_genes: int + The number of non zero genes. + - n_params: int + The number of parameters. + If the gram matrix is full rank, the state contains: + - local_log_features: ndarray + The local log features, only if the gram matrix is full rank. + - global_gram_matrix: ndarray + The global gram matrix, only if the gram matrix is full rank. + If the gram matrix is not full rank, the state contains: + - normed_log_means: ndarray + The normed log means, only if the gram matrix is not full rank. + - n_obs: int + The number of observations, only if the gram matrix is not full rank. + + + Returns + ------- + dict[str, Any] + A dictionary containing all the necessary info to run IRLS. 
+ It contains the following fields: + - beta: ndarray + The initial beta, of shape (n_non_zero_genes, n_params). + - irls_diverged_mask: ndarray + A boolean mask indicating if fed avg should be used for a given gene + (shape: (n_non_zero_genes,)). Is set to False initially, and will + be set to True if the gene has diverged. + - irls_mask: ndarray + A boolean mask indicating if IRLS should be used for a given gene + (shape: (n_non_zero_genes,)). Is set to True initially, and will be + set to False if the gene has converged or diverged. + - global_nll: ndarray + The global_nll of the current beta from the previous beta, of shape + (n_non_zero_genes,). + - round_number_irls: int + The current round number of the IRLS algorithm. + + + """ + # Get the global quantities + gram_full_rank = shared_states[0]["gram_full_rank"] + n_non_zero_genes = shared_states[0]["n_non_zero_genes"] + + # Step 1: Get the beta init + # Condition on whether or not the gram matrix is full rank + if gram_full_rank: + # Get global gram matrix + global_gram_matrix = shared_states[0]["global_gram_matrix"] + + # Aggregate the feature vectors + feature_vectors = sum( + [state["local_log_features"] for state in shared_states] + ) + + # Solve the system + beta_init = np.linalg.solve(global_gram_matrix, feature_vectors.T).T + + else: + # Aggregate the log means + tot_counts = sum([state["n_obs"] for state in shared_states]) + beta_init = ( + sum( + [ + state["normed_log_means"] * state["n_obs"] + for state in shared_states + ] + ) + / tot_counts + ) + + # Step 2: instantiate other necessary quantities + irls_diverged_mask = np.full(n_non_zero_genes, False) + irls_mask = np.full(n_non_zero_genes, True) + global_nll = np.full(n_non_zero_genes, 1000.0) + + return { + "beta": beta_init, + "irls_diverged_mask": irls_diverged_mask, + "irls_mask": irls_mask, + "global_nll": global_nll, + "round_number_irls": 0, + } + + +class LocSaveLFC: + """Mixin to create the local quantities to compute the final hat matrix. + + Attributes + ---------- + local_adata : AnnData + The local AnnData object. + num_jobs : int + The number of cpus to use. + joblib_verbosity : int + The verbosity of the joblib backend. + joblib_backend : str + The backend to use for the joblib parallelization. + irls_batch_size : int + The batch size to use for the IRLS algorithm. + min_mu : float + The minimum value for the mu parameter. + + Methods + ------- + make_local_final_hat_matrix_summands + A remote_data method. Creates the local quantities to compute the + final hat matrix, which must be computed on all genes. This step + is expected to be applied after catching the IRLS method + with the fed prox quasi newton method, and takes as an input a + shared state from the last iteration of that method. + + """ + + local_adata: AnnData + refit_adata: AnnData + num_jobs: int + joblib_verbosity: int + joblib_backend: str + irls_batch_size: int + min_mu: float + irls_num_iter: int + + @remote_data + @log_remote_data + @reconstruct_adatas + def save_lfc_to_local( + self, + data_from_opener: AnnData, + shared_state: dict[str, Any], + refit_mode: bool = False, + ): + """Create the local quantities to compute the final hat matrix. + + Parameters + ---------- + data_from_opener : AnnData + Not used. + + shared_state : dict + The shared state. + The shared state is a dictionary containing the following + keys: + - beta: ndarray + The current beta, of shape (n_non_zero_genes, n_params). + - irls_diverged_mask: ndarray + A boolean mask indicating if the irsl method has diverged. 
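        To make the full-rank branch concrete, here is a centralised toy
        reproduction of the beta initialisation described above, with made-up
        centres, counts and design matrices; only the Gram-matrix and
        log-feature summands would actually be shared in the federated setting.

        import numpy as np

        designs = [
            np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
            np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 1.0], [1.0, 0.0]]),
        ]
        normed_counts = [
            np.array([[5.0, 12.0, 3.0], [7.0, 9.0, 2.0], [20.0, 8.0, 6.0],
                      [18.0, 10.0, 5.0], [22.0, 11.0, 4.0]]),
            np.array([[6.0, 10.0, 2.0], [19.0, 9.0, 5.0], [21.0, 12.0, 7.0],
                      [8.0, 11.0, 3.0]]),
        ]

        # Local shares: Gram-matrix summand and log-feature summand per centre.
        grams = [d.T @ d for d in designs]
        features = [(d.T @ np.log(c + 0.1)).T for d, c in zip(designs, normed_counts)]

        # Aggregation: pool the summands and solve one system for all genes.
        global_gram = sum(grams)                                     # (n_params, n_params)
        beta_init = np.linalg.solve(global_gram, sum(features).T).T  # (n_genes, n_params)
        print(beta_init)

        When the pooled Gram matrix is rank-deficient, the aggregation instead
        falls back to a sample-size-weighted average of the per-centre normed
        log means, as in the non-full-rank branch above.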
+ In that case, these genes are caught with the fed prox newton + method. + (shape: (n_non_zero_genes,)). + - PQN_diverged_mask: ndarray + A boolean mask indicating if the fed prox newton method has + diverged. These genes are not caught by any method, and the + returned beta value is the output of the PQN method, even + though it has not converged. + + refit_mode : bool + Whether to run the pipeline on `refit_adata` instead of `local_adata`. + (default: False). + + """ + beta = shared_state["beta"] + + if refit_mode: + adata = self.refit_adata + else: + adata = self.local_adata + + # TODO keeping this in memory for now, see if need for removal at the end + adata.uns["_irls_diverged_mask"] = shared_state["irls_diverged_mask"] + adata.uns["_PQN_diverged_mask"] = shared_state["PQN_diverged_mask"] + + # Get the param names stored in the local adata + mu_param_name = adata.uns["_irls_mu_param_name"] + beta_param_name = adata.uns["_irls_beta_param_name"] + # ---- Step 2: Store the mu, the diagonal of the hat matrix ---- # + # ---- and beta in the adata ---- # + + design_column_names = adata.obsm["design_matrix"].columns + + non_zero_genes_names = adata.var_names[adata.varm["non_zero"]] + + beta_dataframe = pd.DataFrame( + np.NaN, index=adata.var_names, columns=design_column_names + ) + beta_dataframe.loc[non_zero_genes_names, :] = beta + + adata.varm[beta_param_name] = beta_dataframe + + if mu_param_name is not None: + set_mu_layer( + local_adata=adata, + lfc_param_name=beta_param_name, + mu_param_name=mu_param_name, + n_jobs=self.num_jobs, + joblib_verbosity=self.joblib_verbosity, + joblib_backend=self.joblib_backend, + batch_size=self.irls_batch_size, + ) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/utils.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/utils.py new file mode 100644 index 0000000..9629c44 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/utils.py @@ -0,0 +1,58 @@ +"""Module to implement the utilities of the IRLS algorithm. + +Most of these functions have the _batch suffix, which means that they are +vectorized to work over batches of genes in the parralel_backend file in +the same module. +""" + + +import numpy as np + +from fedpydeseq2.core.utils.negative_binomial import grid_nb_nll + + +def make_irls_nll_batch( + beta: np.ndarray, + design_matrix: np.ndarray, + size_factors: np.ndarray, + dispersions: np.ndarray, + counts: np.ndarray, + min_mu: float = 0.5, +) -> np.ndarray: + """ + Compute the negative binomial log likelihood from LFC estimates. + + Used in ComputeLFC to compute the deviance score. This function is vectorized to + work over batches of genes. + + Parameters + ---------- + beta : np.ndarray + Current LFC estimate, of shape (batch_size, n_params). + design_matrix : np.ndarray + The design matrix, of shape (n_obs, n_params). + size_factors : np.ndarray + The size factors, of shape (n_obs). + dispersions : np.ndarray + The dispersions, of shape (batch_size). + counts : np.ndarray + The counts, of shape (n_obs,batch_size). + min_mu : float + Lower bound on estimated means, to ensure numerical stability. + (default: ``0.5``). + + Returns + ------- + np.ndarray + Local negative binomial log-likelihoods, of shape + (batch_size). 
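        For reference, a self-contained restatement of the negative-binomial
        (NB2) log-likelihood that this deviance computation relies on is given
        below; it is an illustrative stand-in, not the project's grid_nb_nll
        helper, and ignores any clipping or vectorisation details of that
        function.

        import numpy as np
        from scipy.special import gammaln

        def nb_nll_sketch(counts, mu, alpha):
            # Per-gene NB2 negative log-likelihood summed over samples.
            # counts, mu: (n_obs, batch_size); alpha: (batch_size,).
            r = 1.0 / alpha  # inverse dispersion
            log_lik = (
                gammaln(counts + r)
                - gammaln(r)
                - gammaln(counts + 1.0)
                + r * np.log(r / (r + mu))
                + counts * np.log(mu / (r + mu))
            )
            return -log_lik.sum(axis=0)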
+ """ + mu = np.maximum( + size_factors[:, None] * np.exp(design_matrix @ beta.T), + min_mu, + ) + return grid_nb_nll( + counts, + mu, + dispersions, + ) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/deseq2_lfc_dispersions.py b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/deseq2_lfc_dispersions.py new file mode 100644 index 0000000..d3edb38 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/deseq2_lfc_dispersions.py @@ -0,0 +1,173 @@ +from loguru import logger + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_dispersion_prior import ( # noqa: E501 + ComputeDispersionPrior, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions import ( # noqa: E501 + ComputeGenewiseDispersions, +) +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_lfc import ComputeLFC +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_MAP_dispersions import ( # noqa: E501 + ComputeMAPDispersions, +) +from fedpydeseq2.core.utils import local_step + + +class DESeq2LFCDispersions( + ComputeGenewiseDispersions, + ComputeDispersionPrior, + ComputeMAPDispersions, + ComputeLFC, +): + """Mixin class to compute the log fold change and the dispersions with DESeq2. + + This class encapsulates the steps to compute the log fold change and the + dispersions from a given count matrix and a design matrix. + + Methods + ------- + run_deseq2_lfc_dispersions + The method to compute the log fold change and the dispersions. + It starts from the design matrix and the count matrix. + It returns the shared states by the local nodes after the computation of Cook's + distances. + It is meant to be run two times in the main pipeline if Cook's refitting + is applied/ + """ + + def run_deseq2_lfc_dispersions( + self, + train_data_nodes, + aggregation_node, + local_states, + gram_features_shared_states, + round_idx, + clean_models, + refit_mode=False, + ): + """ + Run the DESeq2 pipeline to compute the log fold change and the dispersions. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: list[dict] + Local states. Required to propagate intermediate results. + + gram_features_shared_states: list[dict] + Output of the "compute_size_factor step" if refit_mode is False. + Output of the "replace_outliers" step if refit_mode is True. + In both cases, contains a "local_features" key with the features vector + to input to compute_genewise_dispersion. + In the non refit mode case, it also contains a "local_gram_matrix" key + with the local gram matrix. + + round_idx: int + Index of the current round. + + clean_models: bool + Whether to clean the models after the computation. + + refit_mode: bool + Whether we are refittinh Cooks outliers or not. + + + Returns + ------- + local_states: dict + Local states updated with the results of the DESeq2 pipeline. + + round_idx: int + The updated round index. + + """ + #### Fit genewise dispersions #### + + # Note : for optimization purposes, we could avoid two successive local + # steps here, at the cost of a more complex initialization of the + # fit_dispersions method. 
+ logger.info("Fit genewise dispersions...") + ( + local_states, + genewise_dispersions_shared_state, + round_idx, + ) = self.fit_genewise_dispersions( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + gram_features_shared_states=gram_features_shared_states, + local_states=local_states, + round_idx=round_idx, + clean_models=clean_models, + refit_mode=refit_mode, + ) + logger.info("Finished fitting genewise dispersions.") + + if not refit_mode: + #### Fit dispersion trends #### + logger.info("Compute dispersion prior...") + ( + local_states, + dispersion_trend_share_state, + round_idx, + ) = self.compute_dispersion_prior( + train_data_nodes, + aggregation_node, + local_states, + genewise_dispersions_shared_state, + round_idx, + clean_models, + ) + logger.info("Finished computing dispersion prior.") + else: + # Just update the fitted dispersions + ( + local_states, + dispersion_trend_share_state, + round_idx, + ) = local_step( + local_method=self.loc_update_fitted_dispersions, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=genewise_dispersions_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Update fitted dispersions", + clean_models=clean_models, + ) + + #### Fit MAP dispersions #### + logger.info("Fit MAP dispersions...") + ( + local_states, + round_idx, + ) = self.fit_MAP_dispersions( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + shared_state=dispersion_trend_share_state if not refit_mode else None, + round_idx=round_idx, + clean_models=clean_models, + refit_mode=refit_mode, + ) + logger.info("Finished fitting MAP dispersions.") + + #### Compute log fold changes #### + logger.info("Compute log fold changes...") + local_states, round_idx = self.compute_lfc( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=True, + lfc_mode="lfc", + refit_mode=refit_mode, + ) + logger.info("Finished computing log fold changes.") + + return local_states, round_idx diff --git a/fedpydeseq2/core/deseq2_core/deseq2_stats/__init__.py b/fedpydeseq2/core/deseq2_core/deseq2_stats/__init__.py new file mode 100644 index 0000000..0cc8967 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_stats/__init__.py @@ -0,0 +1,3 @@ +"""Module containing all the necessary steps to perform statistical analysis.""" + +from fedpydeseq2.core.deseq2_core.deseq2_stats.deseq2_stats import DESeq2Stats diff --git a/fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/__init__.py b/fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/__init__.py new file mode 100644 index 0000000..93dd403 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/__init__.py @@ -0,0 +1,5 @@ +"""Module containing the Mixin to compute adjusted p-values.""" + +from fedpydeseq2.core.deseq2_core.deseq2_stats.compute_padj.compute_padj import ( + ComputeAdjustedPValues, +) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/compute_padj.py b/fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/compute_padj.py new file mode 100644 index 0000000..51f4263 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/compute_padj.py @@ -0,0 +1,95 @@ +from fedpydeseq2.core.deseq2_core.deseq2_stats.compute_padj.substeps import ( + IndependentFiltering, +) +from fedpydeseq2.core.deseq2_core.deseq2_stats.compute_padj.substeps import ( + PValueAdjustment, +) +from 
fedpydeseq2.core.utils import local_step + + +class ComputeAdjustedPValues(IndependentFiltering, PValueAdjustment): + """Mixin class to implement the computation of adjusted p-values. + + Attributes + ---------- + independent_filter: bool + A boolean flag to indicate whether to use independent filtering or not. + + Methods + ------- + compute_adjusted_p_values + A method to compute adjusted p-values. + Runs independent filtering if self.independent_filter is True. + Runs BH method otherwise. + + """ + + independent_filter: bool = False + + def compute_adjusted_p_values( + self, + train_data_nodes, + aggregation_node, + local_states, + wald_test_shared_state, + round_idx, + clean_models, + ): + """Compute adjusted p-values. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + wald_test_shared_state: dict + Shared states containing the Wald test results. + + round_idx: int + The current round. + + clean_models: bool + If True, the models are cleaned. + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + + round_idx: int + The updated round index. + + """ + if self.independent_filter: + local_states, _, round_idx = local_step( + local_method=self.run_independent_filtering, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=wald_test_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Compute adjusted P values using independent filtering.", + clean_models=clean_models, + ) + else: + local_states, _, round_idx = local_step( + local_method=self.run_p_value_adjustment, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=wald_test_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Compute adjusted P values using BH method.", + clean_models=clean_models, + ) + + return local_states, round_idx diff --git a/fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/substeps.py b/fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/substeps.py new file mode 100644 index 0000000..96b7bfd --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_stats/compute_padj/substeps.py @@ -0,0 +1,148 @@ +from typing import Any + +import numpy as np +import pandas as pd +from anndata import AnnData +from pydeseq2.utils import lowess +from scipy.stats import false_discovery_control +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data + + +class IndependentFiltering: + """Mixin class implementing independent filtering. + + Attributes + ---------- + local_adata : AnnData + Local AnnData object. + + alpha : float + Significance level. + + Methods + ------- + run_independent_filtering + Run independent filtering on the p-values trend + """ + + local_adata: AnnData + alpha: float + + @remote_data + @log_remote_data + @reconstruct_adatas + def run_independent_filtering(self, data_from_opener, shared_state: Any): + """Run independent filtering on the p-values trend. + + Parameters + ---------- + data_from_opener : AnnData + Not used. 
+ + shared_state : dict + Shared state containing the results of the wald tests, namely + - "p_values" : p-values + - "wald_statistics" : Wald statistics + - "wald_se" : Wald standard errors + + """ + p_values = shared_state["p_values"] + wald_statistics = shared_state["wald_statistics"] + wald_se = shared_state["wald_se"] + + self.local_adata.varm["p_values"] = p_values + self.local_adata.varm["wald_statistics"] = wald_statistics + self.local_adata.varm["wald_se"] = wald_se + + base_mean = self.local_adata.varm["_normed_means"] + + lower_quantile = np.mean(base_mean == 0) + + if lower_quantile < 0.95: + upper_quantile = 0.95 + else: + upper_quantile = 1 + + theta = np.linspace(lower_quantile, upper_quantile, 50) + cutoffs = np.quantile(base_mean, theta) + + result = pd.DataFrame( + np.nan, index=self.local_adata.var_names, columns=np.arange(len(theta)) + ) + + for i, cutoff in enumerate(cutoffs): + use = (base_mean >= cutoff) & (~np.isnan(p_values)) + U2 = p_values[use] + if not len(U2) == 0: + result.loc[use, i] = false_discovery_control(U2, method="bh") + + num_rej = (result < self.alpha).sum(0) + lowess_res = lowess(theta, num_rej, frac=1 / 5) + + if num_rej.max() <= 10: + j = 0 + else: + residual = num_rej[num_rej > 0] - lowess_res[num_rej > 0] + thresh = lowess_res.max() - np.sqrt(np.mean(residual**2)) + + if np.any(num_rej > thresh): + j = np.where(num_rej > thresh)[0][0] + else: + j = 0 + + self.local_adata.varm["padj"] = result.loc[:, j] + + +class PValueAdjustment: + """Mixin class implementing p-value adjustment. + + Attributes + ---------- + local_adata : AnnData + Local AnnData object. + + Methods + ------- + run_p_value_adjustment + Run p-value adjustment on the p-values trend using the Benjamini-Hochberg + method. + + """ + + local_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def run_p_value_adjustment(self, data_from_opener, shared_state: Any): + """Run p-value adjustment on the p-values trend using the BH method. + + Parameters + ---------- + data_from_opener : AnnData + Not used. 
+ + shared_state : dict + Shared state containing the results of the Wald tests, namely + - "p_values" : p-values, as a numpy array + - "wald_statistics" : Wald statistics + - "wald_se" : Wald standard errors + + """ + p_values = shared_state["p_values"] + wald_statistics = shared_state["wald_statistics"] + wald_se = shared_state["wald_se"] + + self.local_adata.varm["p_values"] = p_values + self.local_adata.varm["wald_statistics"] = wald_statistics + self.local_adata.varm["wald_se"] = wald_se + + padj = pd.Series(np.nan, index=self.local_adata.var_names) + padj.loc[~np.isnan(p_values)] = false_discovery_control( + p_values[~np.isnan(p_values)], method="bh" + ) + + self.local_adata.varm["padj"] = padj diff --git a/fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/__init__.py b/fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/__init__.py new file mode 100644 index 0000000..5afdc51 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/__init__.py @@ -0,0 +1,5 @@ +"""Substep to perform cooks filtering.""" + +from fedpydeseq2.core.deseq2_core.deseq2_stats.cooks_filtering.cooks_filtering import ( + CooksFiltering, +) diff --git a/fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/cooks_filtering.py b/fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/cooks_filtering.py new file mode 100644 index 0000000..246c2b8 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/cooks_filtering.py @@ -0,0 +1,187 @@ +"""Module to implement the base Mixin class for Cooks filtering.""" +from fedpydeseq2.core.deseq2_core.deseq2_stats.cooks_filtering.substeps import ( + AggCooksFiltering, +) +from fedpydeseq2.core.deseq2_core.deseq2_stats.cooks_filtering.substeps import ( + AggMaxCooks, +) +from fedpydeseq2.core.deseq2_core.deseq2_stats.cooks_filtering.substeps import ( + AggMaxCooksCounts, +) +from fedpydeseq2.core.deseq2_core.deseq2_stats.cooks_filtering.substeps import ( + AggregateCooksOutliers, +) +from fedpydeseq2.core.deseq2_core.deseq2_stats.cooks_filtering.substeps import ( + LocCountNumberSamplesAbove, +) +from fedpydeseq2.core.deseq2_core.deseq2_stats.cooks_filtering.substeps import ( + LocFindCooksOutliers, +) +from fedpydeseq2.core.deseq2_core.deseq2_stats.cooks_filtering.substeps import ( + LocGetMaxCooks, +) +from fedpydeseq2.core.deseq2_core.deseq2_stats.cooks_filtering.substeps import ( + LocGetMaxCooksCounts, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class CooksFiltering( + LocFindCooksOutliers, + AggregateCooksOutliers, + LocGetMaxCooks, + AggMaxCooks, + LocGetMaxCooksCounts, + AggMaxCooksCounts, + LocCountNumberSamplesAbove, + AggCooksFiltering, +): + """A class to perform Cooks filtering of p-values. + + Methods + ------- + cooks_filtering + The method to find Cooks outliers. + """ + + def cooks_filtering( + self, + train_data_nodes, + aggregation_node, + local_states, + wald_test_shared_state, + round_idx, + clean_models, + ): + """Perform Cooks filtering. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: list[dict] + Local states. Required to propagate intermediate results. + + wald_test_shared_state : dict + A shared state containing the Wald test results. + These results are the following fields: + - "p_values": p-values of the Wald test. + - "wald_statistics" : Wald statistics. + - "wald_se" : Wald standard errors. 
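        As a small, self-contained illustration of the Benjamini-Hochberg
        adjustment performed by the PValueAdjustment substep above (toy
        p-values only; NaNs stand for genes already removed by filtering):

        import numpy as np
        from scipy.stats import false_discovery_control

        p_values = np.array([0.001, np.nan, 0.04, 0.2, np.nan, 0.8])

        padj = np.full_like(p_values, np.nan)
        mask = ~np.isnan(p_values)
        padj[mask] = false_discovery_control(p_values[mask], method="bh")
        print(padj)  # adjusted p-values, NaN kept where the input was NaN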
+ + round_idx: int + Index of the current round. + + clean_models: bool + Whether to clean the models after the computation. + + Returns + ------- + local_states: dict + Local states. The new local state contains Cook's distances. + + shared_state: dict + A new shared state containing the following fields: + - "p_values": p-values of the Wald test, updated to be nan for Cook's + outliers. + - "wald_statistics" : Wald statistics, for compatibility. + - "wald_se" : Wald standard errors, for compatibility. + + round_idx: int + The updated round index. + + """ + local_states, shared_states, round_idx = local_step( + local_method=self.find_local_cooks_outliers, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=wald_test_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Find local Cook's outliers", + round_idx=round_idx, + clean_models=clean_models, + ) + + cooks_outliers_shared_state, round_idx = aggregation_step( + aggregation_method=self.agg_cooks_outliers, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Find the global Cook's outliers", + round_idx=round_idx, + clean_models=clean_models, + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.get_max_local_cooks, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=cooks_outliers_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Get local max cooks distance", + round_idx=round_idx, + clean_models=clean_models, + ) + + max_cooks_distance_shared_state, round_idx = aggregation_step( + aggregation_method=self.agg_max_cooks, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Get the max cooks distance for the outliers", + round_idx=round_idx, + clean_models=clean_models, + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.get_max_local_cooks_gene_counts, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=max_cooks_distance_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Get the local max gene counts for the Cook's outliers", + round_idx=round_idx, + clean_models=clean_models, + ) + + max_cooks_gene_counts_shared_state, round_idx = aggregation_step( + aggregation_method=self.agg_max_cooks_gene_counts, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Get the max gene counts for the Cook's outliers", + round_idx=round_idx, + clean_models=clean_models, + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.count_local_number_samples_above, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=max_cooks_gene_counts_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Count the number of samples above the max gene counts", + round_idx=round_idx, + clean_models=clean_models, + ) + + cooks_filtered_shared_state, round_idx = aggregation_step( + aggregation_method=self.agg_cooks_filtering, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Finish Cooks 
filtering", + round_idx=round_idx, + clean_models=clean_models, + ) + + return local_states, cooks_filtered_shared_state, round_idx diff --git a/fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/substeps.py b/fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/substeps.py new file mode 100644 index 0000000..b922fcb --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_stats/cooks_filtering/substeps.py @@ -0,0 +1,533 @@ +import numpy as np +import pandas as pd +from anndata import AnnData +from scipy.stats import f # type: ignore +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import prepare_cooks_agg +from fedpydeseq2.core.utils.layers import prepare_cooks_local +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocFindCooksOutliers: + """Mixin class to find the local cooks outliers. + + Attributes + ---------- + local_adata : AnnData + Local AnnData object. + Is expected to have a "tot_num_samples" key in uns. + + refit_cooks : bool + Whether to refit the cooks outliers. + + + Methods + ------- + find_local_cooks_outliers + Find the local cooks outliers. + + """ + + local_adata: AnnData + refit_cooks: bool + + @remote_data + @log_remote_data + @reconstruct_adatas + @prepare_cooks_local + def find_local_cooks_outliers( + self, + data_from_opener, + shared_state: dict, + ) -> dict: + """Find the local cooks outliers. + + This method is expected to run on the results of the Wald tests. + + Parameters + ---------- + data_from_opener : AnnData + Not used. + + shared_state : dict + Shared state from the previous step with the following + keys: + - p_values: np.ndarray of shape (n_genes,) + - wald_statistics: np.ndarray of shape (n_genes,) + - wald_se: np.ndarray of shape (n_genes,) + + Returns + ------- + shared_state : dict + A shared state with the following fields: + - local_cooks_outliers: np.ndarray of shape (n_genes,) + It is a boolean array indicating whether a gene is a cooks outlier. + - cooks_cutoff: float + The cutoff used to define the fact that a gene is a cooks outlier. 
+ + """ + # Save these in the local adata + self.local_adata.varm["p_values"] = shared_state["p_values"] + self.local_adata.varm["wald_statistics"] = shared_state["wald_statistics"] + self.local_adata.varm["wald_se"] = shared_state["wald_se"] + + tot_num_samples = self.local_adata.uns["tot_num_samples"] + num_vars = self.local_adata.uns["n_params"] + + cooks_cutoff = f.ppf(0.99, num_vars, tot_num_samples - num_vars) + + # Take into account whether we already replaced outliers + cooks_layer = ( + "replace_cooks" + if self.refit_cooks and self.local_adata.varm["refitted"].sum() > 0 + else "cooks" + ) + + if cooks_layer == "replace_cooks": + assert "replaced" in self.local_adata.varm.keys() + replace_cooks = pd.DataFrame(self.local_adata.layers["cooks"].copy()) + replace_cooks.loc[ + self.local_adata.obsm["replaceable"], self.local_adata.varm["refitted"] + ] = 0.0 + self.local_adata.layers["replace_cooks"] = replace_cooks + + use_for_max = self.local_adata.obs["cells"].apply( + lambda x: (self.local_adata.uns["num_replicates"] >= 3).loc[x] + ) + + cooks_outliers = ( + (self.local_adata[use_for_max, :].layers[cooks_layer] > cooks_cutoff) + .any(axis=0) + .copy() + ) + + return {"local_cooks_outliers": cooks_outliers, "cooks_cutoff": cooks_cutoff} + + +class AggregateCooksOutliers: + """Mixin class to aggregate the cooks outliers. + + Methods + ------- + agg_cooks_outliers + Aggregate the local cooks outliers. + + """ + + @remote + @log_remote + @prepare_cooks_agg + def agg_cooks_outliers(self, shared_states: list[dict]) -> dict: + """ + Aggregate the local cooks outliers. + + Parameters + ---------- + shared_states : list[dict] + List of shared states from the local step with the following keys: + - local_cooks_outliers: np.ndarray of shape (n_genes,) + - cooks_cutoff: float + + Returns + ------- + shared_state : dict + Aggregated cooks outliers. + It is a dictionary with the following fields: + - cooks_outliers: np.ndarray of shape (n_genes,) + It is a boolean array indicating whether a gene is a cooks outlier in + any of the local datasets + - cooks_cutoff: float + The cutoff used to define the fact that a gene is a cooks outlier. + """ + return { + "cooks_outliers": np.any( + [state["local_cooks_outliers"] for state in shared_states], axis=0 + ), + "cooks_cutoff": shared_states[0]["cooks_cutoff"], + } + + +class LocGetMaxCooks: + """Mixin class to get the maximum cooks distance for the outliers. + + Attributes + ---------- + local_adata : AnnData + Local AnnData object. + + Methods + ------- + get_max_local_cooks + Get the maximum cooks distance for the outliers. + + """ + + local_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_max_local_cooks( + self, + data_from_opener, + shared_state: dict, + ) -> dict: + """Get the maximum cooks distance for the outliers. + + Parameters + ---------- + data_from_opener : AnnData + Not used. + + shared_state : dict + Shared state from the previous step with the following + keys: + - cooks_outliers: np.ndarray of shape (n_genes,) + It is a boolean array indicating whether a gene is a cooks outlier. + - cooks_cutoff: float + + Returns + ------- + shared_state : dict + A shared state with the following fields: + - local_max_cooks: np.ndarray of shape (n_cooks_genes,) + The maximum cooks distance for the outliers in the local dataset. + - cooks_outliers: np.ndarray of shape (n_genes,) + It is a boolean array indicating whether a gene is a cooks outlier. 
+ + """ + cooks_outliers = shared_state["cooks_outliers"] + cooks_cutoff = shared_state["cooks_cutoff"] + + max_cooks = np.max(self.local_adata.layers["cooks"][:, cooks_outliers], axis=0) + + max_cooks[max_cooks <= cooks_cutoff] = 0.0 + + max_cooks_idx = self.local_adata.layers["cooks"][:, cooks_outliers].argmax( + axis=0 + ) + + max_cooks_value = self.local_adata.layers["cooks"][:, cooks_outliers][ + max_cooks_idx, np.arange(len(max_cooks)) + ] + + max_cooks_gene_counts = self.local_adata.X[:, cooks_outliers][ + max_cooks_idx, np.arange(len(max_cooks)) + ] + + # Save the max cooks gene counts and max cooks value + self.local_adata.uns["max_cooks_gene_counts"] = max_cooks_gene_counts + self.local_adata.uns["max_cooks_value"] = max_cooks_value + + return { + "local_max_cooks": max_cooks, + "cooks_outliers": cooks_outliers, + } + + +class AggMaxCooks: + """Mixin class to aggregate the max cooks distances. + + Methods + ------- + agg_max_cooks + Aggregate the local max cooks distances. + + """ + + @remote + @log_remote + def agg_max_cooks(self, shared_states: list[dict]) -> dict: + """ + Aggregate the local max cooks. + + Parameters + ---------- + shared_states : list[dict] + List of shared states from the local step with the following keys: + - local_max_cooks: np.ndarray of shape (n_genes,) + The local maximum cooks distance for the outliers. + - cooks_outliers: np.ndarray of shape (n_genes,) + It is a boolean array indicating whether a gene is a cooks outlier. + + Returns + ------- + shared_state : dict + Aggregated max cooks. + It is a dictionary with the following fields: + - max_cooks: np.ndarray of shape (n_cooks_genes,) + The maximum cooks distance for the outliers in the aggregated dataset. + - cooks_outliers: np.ndarray of shape (n_genes,) + It is a boolean array indicating whether a gene is a cooks outlier. + """ + return { + "max_cooks": np.max( + [state["local_max_cooks"] for state in shared_states], axis=0 + ), + "cooks_outliers": shared_states[0]["cooks_outliers"], + } + + +class LocGetMaxCooksCounts: + """Mixin class to get the maximum cooks counts for the outliers. + + Attributes + ---------- + local_adata : AnnData + Local AnnData object. + + Methods + ------- + get_max_local_cooks_gene_counts + Get the maximum cooks counts for the outliers. + + """ + + local_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_max_local_cooks_gene_counts( + self, + data_from_opener, + shared_state: dict, + ) -> dict: + """Get the maximum cooks counts for the outliers. + + Parameters + ---------- + data_from_opener : AnnData + Not used. + + shared_state : dict + Shared state from the previous step with the following + keys: + - max_cooks: np.ndarray of shape (n_cooks_genes,) + The maximum cooks distance for the outliers. + - cooks_outliers: np.ndarray of shape (n_genes,) + It is a boolean array indicating whether a gene is a cooks outlier. + + Returns + ------- + shared_state : dict + A shared state with the following fields: + - local_max_cooks_gene_counts: np.ndarray of shape (n_cooks_genes,) + For each gene, the array contains the gene counts corresponding to the + maximum cooks distance for that gene if the maximum cooks distance + in the local dataset is equal to the maximum cooks distance in the + aggregated dataset, and nan otherwise. + - cooks_outliers: np.ndarray of shape (n_genes,) + It is a boolean array indicating whether a gene is a cooks outlier. 
+ + """ + max_cooks = shared_state["max_cooks"] + cooks_outliers = shared_state["cooks_outliers"] + + max_cooks_gene_counts = self.local_adata.uns["max_cooks_gene_counts"].copy() + max_cooks_value = self.local_adata.uns["max_cooks_value"].copy() + + # Remove them from the uns field as they are no longer needed + del self.local_adata.uns["max_cooks_gene_counts"] + del self.local_adata.uns["max_cooks_value"] + + max_cooks_gene_counts[ + max_cooks_value < max_cooks + ] = -1 # We can use a < because the count value are non negative integers. + + max_cooks_gene_counts_ma = np.ma.masked_array( + max_cooks_gene_counts, max_cooks_gene_counts == -1 + ) + + return { + "local_max_cooks_gene_counts": max_cooks_gene_counts_ma, + "cooks_outliers": cooks_outliers, + } + + +class AggMaxCooksCounts: + """Mixin class to aggregate the max cooks gene counts. + + Methods + ------- + agg_max_cooks_gene_counts + Aggregate the local max cooks gene counts. The goal is to have the gene + counts corresponding to the maximum cooks distance for each gene across + all datasets. + + """ + + @remote + @log_remote + def agg_max_cooks_gene_counts(self, shared_states: list[dict]) -> dict: + """ + Aggregate the local max cooks gene counts. + + Parameters + ---------- + shared_states : list[dict] + List of shared states from the local step with the following keys: + - local_max_cooks_gene_counts: np.ndarray of shape (n_genes,) + The local maximum cooks gene counts for the outliers. + - cooks_outliers: np.ndarray of shape (n_genes,) + It is a boolean array indicating whether a gene is a cooks outlier. + + Returns + ------- + shared_state : dict + A shared state with the following fields: + - max_cooks_gene_counts: np.ndarray of shape (n_cooks_genes,) + For each gene, the array contains the gene counts corresponding to the + maximum cooks distance for that gene across all datasets. + - cooks_outliers: np.ndarray of shape (n_genes,) + It is a boolean array indicating whether a gene is a cooks outlier. + + """ + return { + "max_cooks_gene_counts": np.ma.stack( + [state["local_max_cooks_gene_counts"] for state in shared_states], + axis=0, + ).min(axis=0), + "cooks_outliers": shared_states[0]["cooks_outliers"], + } + + +class LocCountNumberSamplesAbove: + """Mixin class to count the number of samples above the max cooks gene counts. + + Attributes + ---------- + local_adata : AnnData + + Methods + ------- + count_local_number_samples_above + Count the number of samples above the max cooks gene counts. + + """ + + local_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def count_local_number_samples_above( + self, + data_from_opener, + shared_state: dict, + ) -> dict: + """Count the number of samples above the max cooks gene counts. + + Parameters + ---------- + data_from_opener : AnnData + Not used. + + shared_state : dict + Shared state from the previous step with the following + keys: + - cooks_outliers: np.ndarray of shape (n_genes,) + It is a boolean array indicating whether a gene is a cooks outlier. + - max_cooks_gene_counts: np.ndarray of shape (n_genes,) + For each gene, the array contains the gene counts corresponding to the + maximum cooks distance for that gene across all datasets. + + Returns + ------- + shared_state : dict + A shared state with the following fields: + - local_num_samples_above: np.ndarray of shape (n_cooks_genes,) + For each gene, the array contains the number of samples above the + maximum cooks gene counts. 
+          - cooks_outliers: np.ndarray of shape (n_genes,)
+            It is a boolean array indicating whether a gene is a cooks outlier.
+          - p_values: np.ndarray of shape (n_genes,)
+            The p-values from the Wald test.
+          - wald_statistics: np.ndarray of shape (n_genes,)
+            The Wald statistics from the Wald test.
+          - wald_se: np.ndarray of shape (n_genes,)
+            The Wald standard errors from the Wald test.
+
+        """
+        cooks_outliers = shared_state["cooks_outliers"]
+        max_cooks_gene_counts = shared_state["max_cooks_gene_counts"]
+
+        num_samples_above = np.sum(
+            self.local_adata.X[:, cooks_outliers] > max_cooks_gene_counts, axis=0
+        )
+
+        return {
+            "local_num_samples_above": num_samples_above,
+            "cooks_outliers": cooks_outliers,
+            "p_values": self.local_adata.varm["p_values"],
+            "wald_statistics": self.local_adata.varm["wald_statistics"],
+            "wald_se": self.local_adata.varm["wald_se"],
+        }
+
+
+class AggCooksFiltering:
+    """Mixin class to aggregate the cooks filtering.
+
+    Methods
+    -------
+    agg_cooks_filtering
+        Aggregate the local number of samples above.
+
+    """
+
+    @remote
+    @log_remote
+    def agg_cooks_filtering(self, shared_states: list[dict]) -> dict:
+        """
+        Aggregate the local number of samples above to get the Cook's-filtered genes.
+
+        Parameters
+        ----------
+        shared_states : list[dict]
+            List of shared states from the local step with the following keys:
+            - local_num_samples_above: np.ndarray of shape (n_genes,)
+                The local number of samples above the max cooks gene counts.
+            - cooks_outliers: np.ndarray of shape (n_genes,)
+                It is a boolean array indicating whether a gene is a cooks outlier.
+            - p_values: np.ndarray of shape (n_genes,)
+                The p-values from the Wald test.
+            - wald_statistics: np.ndarray of shape (n_genes,)
+                The Wald statistics from the Wald test.
+            - wald_se: np.ndarray of shape (n_genes,)
+                The Wald standard errors from the Wald test.
+
+        Returns
+        -------
+        dict
+            A shared state with the following fields:
+            - p_values: np.ndarray of shape (n_genes,)
+                The p-values from the Wald test with nan for the cooks outliers.
+            - wald_statistics: np.ndarray of shape (n_genes,)
+                The Wald statistics.
+            - wald_se: np.ndarray of shape (n_genes,)
+                The Wald standard errors.
+
+        """
+        # Find the number of samples with counts above the max cooks
+        cooks_outliers = shared_states[0]["cooks_outliers"]
+
+        num_samples_above_max_cooks = np.sum(
+            [state["local_num_samples_above"] for state in shared_states], axis=0
+        )
+
+        # If 3 or more samples have counts above the max Cook's counts, the gene
+        # is no longer considered an outlier.
+        cooks_outliers[cooks_outliers] = num_samples_above_max_cooks < 3
+
+        # Set the p-values to nan on cooks outliers
+        p_values = shared_states[0]["p_values"]
+        p_values[cooks_outliers] = np.nan
+
+        return {
+            "p_values": p_values,
+            "wald_statistics": shared_states[0]["wald_statistics"],
+            "wald_se": shared_states[0]["wald_se"],
+        }
diff --git a/fedpydeseq2/core/deseq2_core/deseq2_stats/deseq2_stats.py b/fedpydeseq2/core/deseq2_core/deseq2_stats/deseq2_stats.py
new file mode 100644
index 0000000..b885964
--- /dev/null
+++ b/fedpydeseq2/core/deseq2_core/deseq2_stats/deseq2_stats.py
@@ -0,0 +1,102 @@
+from loguru import logger
+
+from fedpydeseq2.core.deseq2_core.deseq2_stats.compute_padj import (
+    ComputeAdjustedPValues,
+)
+from fedpydeseq2.core.deseq2_core.deseq2_stats.cooks_filtering import CooksFiltering
+from fedpydeseq2.core.deseq2_core.deseq2_stats.wald_tests import RunWaldTests
+
+
+class DESeq2Stats(RunWaldTests, CooksFiltering, ComputeAdjustedPValues):
+    """Mixin class to compute statistics with DESeq2.
+ + This class encapsulates the Wald tests, the Cooks filtering and the computation + of adjusted p-values. + + Methods + ------- + run_deseq2_stats + Run the DESeq2 statistics pipeline. + Performs Wald tests, Cook's filtering and computes adjusted p-values. + + """ + + cooks_filter: bool + + def run_deseq2_stats( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """ + Run the DESeq2 statistics pipeline. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: list[dict] + Local states. Required to propagate intermediate results. + + round_idx: int + Index of the current round. + + clean_models: bool + Whether to clean the models after the computation. + + + Returns + ------- + local_states: dict + Local states. + + round_idx: int + The updated round index. + + """ + #### Perform Wald tests #### + logger.info("Running Wald tests.") + + local_states, wald_shared_state, round_idx = self.run_wald_tests( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + logger.info("Finished running Wald tests.") + + if self.cooks_filter: + logger.info("Running Cook's filtering...") + local_states, wald_shared_state, round_idx = self.cooks_filtering( + train_data_nodes, + aggregation_node, + local_states, + wald_shared_state, + round_idx, + clean_models=clean_models, + ) + logger.info("Finished running Cook's filtering.") + logger.info("Computing adjusted p-values...") + ( + local_states, + round_idx, + ) = self.compute_adjusted_p_values( + train_data_nodes, + aggregation_node, + local_states, + wald_shared_state, + round_idx, + clean_models=clean_models, + ) + logger.info("Finished computing adjusted p-values.") + + return local_states, round_idx diff --git a/fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/__init__.py b/fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/__init__.py new file mode 100644 index 0000000..000141e --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/__init__.py @@ -0,0 +1 @@ +from fedpydeseq2.core.deseq2_core.deseq2_stats.wald_tests.wald_tests import RunWaldTests diff --git a/fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/substeps.py b/fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/substeps.py new file mode 100644 index 0000000..d291191 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/substeps.py @@ -0,0 +1,161 @@ +from typing import Literal + +import anndata as ad +import numpy as np +from joblib import Parallel +from joblib import delayed +from joblib import parallel_backend +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.fed_algorithms.fed_irls.utils import ( + make_irls_update_summands_and_nll_batch, +) +from fedpydeseq2.core.utils import build_contrast_vector +from fedpydeseq2.core.utils import wald_test +from fedpydeseq2.core.utils.layers import prepare_cooks_agg +from fedpydeseq2.core.utils.layers import prepare_cooks_local +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocBuildContrastVectorHMatrix: + """Mixin to get compute contrast vectors and local H matrices.""" + + local_adata: ad.AnnData + num_jobs: int + joblib_verbosity: int + joblib_backend: str + irls_batch_size: int + + @remote_data + @log_remote_data + 
@reconstruct_adatas + @prepare_cooks_local + def compute_contrast_vector_and_H_matrix( + self, + data_from_opener, + shared_state: dict, + ) -> dict: + """Build the contrast vector and the local H matrices. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Not used. + + Returns + ------- + dict + Contains: + - local_H_matrix: np.ndarray + The local H matrix. + - LFC: np.ndarray + The log fold changes, in natural log scale. + - contrast_vector: np.ndarray + The contrast vector. + """ + # Build contrast vector and index + ( + self.local_adata.uns["contrast_vector"], + self.local_adata.uns["contrast_idx"], + ) = build_contrast_vector( + self.local_adata.uns["contrast"], + self.local_adata.varm["LFC"].columns, + ) + + # ---- Compute the summands for the covariance matrix ---- # + + with parallel_backend(self.joblib_backend): + res = Parallel(n_jobs=self.num_jobs, verbose=self.joblib_verbosity)( + delayed(make_irls_update_summands_and_nll_batch)( + self.local_adata.obsm["design_matrix"].values, + self.local_adata.obsm["size_factors"], + self.local_adata.varm["LFC"][i : i + self.irls_batch_size].values, + self.local_adata.varm["dispersions"][i : i + self.irls_batch_size], + self.local_adata.X[:, i : i + self.irls_batch_size], + 0, + ) + for i in range(0, self.local_adata.n_vars, self.irls_batch_size) + ) + + H = np.concatenate([r[0] for r in res]) + + return { + "local_H_matrix": H, + "LFC": self.local_adata.varm["LFC"], + "contrast_vector": self.local_adata.uns["contrast_vector"], + } + + +class AggRunWaldTests: + """Mixin to run Wald tests.""" + + lfc_null: float + alt_hypothesis: Literal["greaterAbs", "lessAbs", "greater", "less"] | None + num_jobs: int + joblib_verbosity: int + joblib_backend: str + + @remote + @log_remote + @prepare_cooks_agg + def agg_run_wald_tests(self, shared_states: list) -> dict: + """Run the Wald tests. + + Parameters + ---------- + shared_states : list + List of shared states containing: + - local_H_matrix: np.ndarray + The local H matrix. + - LFC: np.ndarray + The log fold changes, in natural log scale. + - contrast_vector: np.ndarray + The contrast vector. + + Returns + ------- + dict + Contains: + - p_values: np.ndarray + The (unadjusted) p-values (n_genes,). + - wald_statistics: np.ndarray + The Wald statistics (n_genes,). + - wald_se: np.ndarray + The standard errors of the Wald statistics (n_genes,). 
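+
+        Notes
+        -----
+        Each local ``H`` matrix is a sum over the samples of that center, so the
+        pooled matrix is obtained by simply summing the local contributions, as
+        done in the body of this method. A minimal sketch of that aggregation,
+        with hypothetical shapes::
+
+            import numpy as np
+
+            # Two centers, 3 genes, 2 design columns -> one (3, 2, 2) array each.
+            local_H = [np.ones((3, 2, 2)), 2 * np.ones((3, 2, 2))]
+            H = sum(local_H)  # same result as pooling all samples first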
+ """ + # First step: aggregate the local H matrices + + H = sum([state["local_H_matrix"] for state in shared_states]) + + # Second step: compute the Wald tests in parallel + with parallel_backend(self.joblib_backend): + wald_test_results = Parallel( + n_jobs=self.num_jobs, verbose=self.joblib_verbosity + )( + delayed(wald_test)( + H[i], + shared_states[0]["LFC"].values[i], + None, + shared_states[0]["contrast_vector"], + np.log(2) * self.lfc_null, + self.alt_hypothesis, + ) + for i in range(len(H)) + ) + + # Finally, unpack the results + p_values = np.array([r[0] for r in wald_test_results]) + wald_statistics = np.array([r[1] for r in wald_test_results]) + wald_se = np.array([r[2] for r in wald_test_results]) + + return { + "p_values": p_values, + "wald_statistics": wald_statistics, + "wald_se": wald_se, + } diff --git a/fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/wald_tests.py b/fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/wald_tests.py new file mode 100644 index 0000000..7a38aad --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/deseq2_stats/wald_tests/wald_tests.py @@ -0,0 +1,71 @@ +from fedpydeseq2.core.deseq2_core.deseq2_stats.wald_tests.substeps import ( + AggRunWaldTests, +) +from fedpydeseq2.core.deseq2_core.deseq2_stats.wald_tests.substeps import ( + LocBuildContrastVectorHMatrix, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class RunWaldTests(LocBuildContrastVectorHMatrix, AggRunWaldTests): + """Mixin class to implement the computation of the Wald tests. + + Methods + ------- + run_wald_tests + The method to compute the Wald tests. + """ + + def run_wald_tests( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """Compute the Wald tests. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: int + Index of the current round. + + clean_models: bool + Whether to clean the models after the computation. 
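+
+        Returns
+        -------
+        local_states: dict
+            Local states. Required to propagate intermediate results.
+
+        wald_shared_state: dict
+            Shared state containing the Wald test results
+            ("p_values", "wald_statistics" and "wald_se").
+
+        round_idx: int
+            The updated round index.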
+ """ + # --- Build contrast vectors and compute local H matrices --- # + local_states, shared_states, round_idx = local_step( + local_method=self.compute_contrast_vector_and_H_matrix, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=None, # TODO plug in previous step + aggregation_id=aggregation_node.organization_id, + description="Build contrast vectors and compute local H matrices", + clean_models=clean_models, + ) + + # --- Aggregate the H matrices and run the Wald tests --- # + wald_shared_state, round_idx = aggregation_step( + aggregation_method=self.agg_run_wald_tests, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Run Wald tests.", + clean_models=clean_models, + ) + + return local_states, wald_shared_state, round_idx diff --git a/fedpydeseq2/core/deseq2_core/replace_outliers/__init__.py b/fedpydeseq2/core/deseq2_core/replace_outliers/__init__.py new file mode 100644 index 0000000..8300296 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/replace_outliers/__init__.py @@ -0,0 +1,3 @@ +from fedpydeseq2.core.deseq2_core.replace_outliers.replace_outliers import ( + ReplaceCooksOutliers, +) diff --git a/fedpydeseq2/core/deseq2_core/replace_outliers/replace_outliers.py b/fedpydeseq2/core/deseq2_core/replace_outliers/replace_outliers.py new file mode 100644 index 0000000..66821ee --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/replace_outliers/replace_outliers.py @@ -0,0 +1,164 @@ +from fedpydeseq2.core.deseq2_core.replace_outliers.substeps import AggMergeOutlierGenes +from fedpydeseq2.core.deseq2_core.replace_outliers.substeps import AggNewAllZeros +from fedpydeseq2.core.deseq2_core.replace_outliers.substeps import LocFindCooksOutliers +from fedpydeseq2.core.deseq2_core.replace_outliers.substeps import ( + LocReplaceCooksOutliers, +) +from fedpydeseq2.core.deseq2_core.replace_outliers.substeps import ( + LocSetNewAllZerosAndGetFeatures, +) +from fedpydeseq2.core.deseq2_core.replace_outliers.substeps import LocSetRefitAdata +from fedpydeseq2.core.fed_algorithms import ComputeTrimmedMean +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class ReplaceCooksOutliers( + ComputeTrimmedMean, + LocFindCooksOutliers, + AggMergeOutlierGenes, + LocReplaceCooksOutliers, + LocSetRefitAdata, + AggNewAllZeros, + LocSetNewAllZerosAndGetFeatures, +): + """Mixin class to replace Cook's outliers.""" + + trimmed_mean_num_iter: int + + def replace_outliers( + self, + train_data_nodes, + aggregation_node, + local_states, + cooks_shared_state, + round_idx, + clean_models, + ): + """Replace outlier counts. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: list[dict] + Local states. Required to propagate intermediate results. + + cooks_shared_state: dict + Shared state with the dispersion values for Cook's distances, in a + "cooks_dispersions" key. + + + round_idx: int + Index of the current round. + + clean_models: bool + Whether to clean the models after the computation. + + Returns + ------- + local_states: dict + Local states. The new local state contains Cook's distances. + + shared_states: list[dict] + List of shared states with the features vector to input to + compute_genewise_dispersion in a "local_features" key. 
+ + round_idx: int + The updated round index. + """ + # Store trimmed means and find local Cooks outliers + local_states, shared_states, round_idx = local_step( + local_method=self.loc_find_cooks_outliers, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=cooks_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Find local Cooks outliers", + round_idx=round_idx, + clean_models=clean_models, + ) + + # Build the global list of genes for which to replace outliers + genes_to_replace_share_state, round_idx = aggregation_step( + aggregation_method=self.agg_merge_outlier_genes, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Merge the lists of local outlier genes", + round_idx=round_idx, + clean_models=clean_models, + ) + + # Store trimmed means and find local Cooks outliers + local_states, shared_states, round_idx = local_step( + local_method=self.loc_set_refit_adata, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=genes_to_replace_share_state, + aggregation_id=aggregation_node.organization_id, + description="Set the refit adata with the genes to replace", + round_idx=round_idx, + clean_models=clean_models, + ) + + # Compute imputation values, on genes to refit only. + local_states, trimmed_means_shared_state, round_idx = self.compute_trim_mean( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + layer_used="normed_counts", + trim_ratio=0.2, + mode="normal", + n_iter=self.trimmed_mean_num_iter, + refit=True, + ) + + # Replace outliers in replaceable samples locally + local_states, shared_states, round_idx = local_step( + local_method=self.loc_replace_cooks_outliers, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=trimmed_means_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Replace Cooks outliers locally", + round_idx=round_idx, + clean_models=clean_models, + ) + + # Find genes who have only have zero counts due to imputation + + new_all_zeros_shared_state, round_idx = aggregation_step( + aggregation_method=self.aggregate_new_all_zeros, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Find new all zero genes", + round_idx=round_idx, + clean_models=clean_models, + ) + + # Set new all zeros genes and get features vector + + local_states, shared_states, round_idx = local_step( + local_method=self.local_set_new_all_zeros_get_features, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=new_all_zeros_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Set new all zero genes and get features vector", + round_idx=round_idx, + clean_models=clean_models, + ) + + return local_states, shared_states, round_idx diff --git a/fedpydeseq2/core/deseq2_core/replace_outliers/substeps.py b/fedpydeseq2/core/deseq2_core/replace_outliers/substeps.py new file mode 100644 index 0000000..7979581 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/replace_outliers/substeps.py @@ -0,0 +1,312 @@ +import anndata as ad +import numpy as np +import pandas as pd +from anndata import AnnData +from scipy.stats import f # type: ignore 
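+# The F distribution is used below to compute the Cook's distance cutoff
+# (the 0.99 quantile of F(n_params, tot_num_samples - n_params)).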
+from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.layers.build_layers import set_normed_counts +from fedpydeseq2.core.utils.layers.build_refit_adata import set_basic_refit_adata +from fedpydeseq2.core.utils.layers.build_refit_adata import ( + set_imputed_counts_refit_adata, +) +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocFindCooksOutliers: + """Find local Cooks outliers.""" + + local_adata: AnnData + min_replicates: int + + @remote_data + @log_remote_data + @reconstruct_adatas + def loc_find_cooks_outliers( + self, + data_from_opener, + shared_state: dict, + ) -> dict: + """ + Find local Cooks outliers by comparing the cooks distance to a threshold. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict, optional + Not used. + + Returns + ------- + dict + Shared state containing: + - "local_genes_to_replace": genes with Cook's distance above the threshold, + - "replaceable_samples": a boolean indicating whether there is at least one + sample with enough replicates to replace it. + + """ + # Find replaceable samples + n_or_more = self.local_adata.uns["num_replicates"] >= self.min_replicates + + self.local_adata.obsm["replaceable"] = n_or_more[ + self.local_adata.obs["cells"] + ].values + + # Find genes with Cook's distance above the threshold + n_params = self.local_adata.uns["n_params"] + cooks_cutoff = f.ppf( + 0.99, n_params, self.local_adata.uns["tot_num_samples"] - n_params + ) + + self.local_adata.uns["_where_cooks_g_cutoff"] = np.where( + self.local_adata.layers["cooks"] > cooks_cutoff + ) + + local_idx_to_replace = (self.local_adata.layers["cooks"] > cooks_cutoff).any( + axis=0 + ) + local_genes_to_replace = self.local_adata.var_names[local_idx_to_replace] + + return { + "local_genes_to_replace": set(local_genes_to_replace), + "replaceable_samples": self.local_adata.obsm["replaceable"].any(), + } + + +class AggMergeOutlierGenes: + """Build the global list of genes to replace.""" + + @remote + @log_remote + def agg_merge_outlier_genes( + self, + shared_states: list[dict], + ) -> dict: + """ + Merge the lists of genes to replace. + + Parameters + ---------- + shared_states : list + List of dictionaries containing: + - "local_genes_to_replace": genes with Cook's distance above the threshold, + - "replaceable_samples": a boolean indicating whether there is at least + one sample with enough replicates to replace it. + + Returns + ------- + dict + A dictionary with a unique key: "genes_to_replace" containing the list + of genes for which to replace outlier values. 
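+
+        Notes
+        -----
+        A minimal sketch of the aggregation rule, with hypothetical shared
+        states::
+
+            shared_states = [
+                {"local_genes_to_replace": {"geneA"}, "replaceable_samples": True},
+                {"local_genes_to_replace": {"geneB"}, "replaceable_samples": False},
+            ]
+
+            if not any(state["replaceable_samples"] for state in shared_states):
+                genes_to_replace = set()
+            else:
+                genes_to_replace = set.union(
+                    *[state["local_genes_to_replace"] for state in shared_states]
+                )
+            # genes_to_replace == {"geneA", "geneB"}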
+ """ + # If no sample is replaceable, we can skip + any_replaceable = any(state["replaceable_samples"] for state in shared_states) + + if not any_replaceable: + return {"genes_to_replace": set()} + + else: + # Take the union of all local list of genes to replace + genes_to_replace = set.union( + *[state["local_genes_to_replace"] for state in shared_states] + ) + + return { + "genes_to_replace": genes_to_replace, + } + + +class LocSetRefitAdata: + """Mixin to replace cooks outliers locally.""" + + local_adata: AnnData + refit_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def loc_set_refit_adata( + self, + data_from_opener, + shared_state: dict, + ) -> None: + """ + Set a refit adata containing the counts of the genes to replace. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + A dictionary with a "genes_to_replace" key, containing the list of genes + for which to replace outlier values. + """ + # Save the information on which genes will be replaced + genes_to_replace = pd.Series(False, index=self.local_adata.var_names) + genes_to_replace[list(shared_state["genes_to_replace"])] = True + self.local_adata.varm["replaced"] = genes_to_replace.values + + # Copy the values corresponding to the genes to refit in the refit_adata + set_basic_refit_adata(self) + + +class LocReplaceCooksOutliers: + """Mixin to replace cooks outliers locally.""" + + local_adata: AnnData + refit_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def loc_replace_cooks_outliers( + self, + data_from_opener, + shared_state: dict, + ) -> dict: + """ + Replace outlier counts with imputed values. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + A dictionary with a "trimmed_mean_normed_counts" key, containing the + trimmed means to use to compute the imputed values. + + Returns + ------- + dict + A dictionary containing: + - "loc_new_all_zero": a boolean array indicating which genes are now + all-zero. + """ + # Set the trimmed mean normed counts in the varm + self.refit_adata.varm["_trimmed_mean_normed_counts"] = shared_state[ + "trimmed_mean_normed_counts" + ] + + set_imputed_counts_refit_adata(self) + + # Find new all-zero columns + new_all_zeroes = self.refit_adata.X.sum(axis=0) == 0 + + # Return the new local logmeans + with np.errstate(divide="ignore"): # ignore division by zero warnings + return { + "loc_new_all_zeroes": new_all_zeroes, + } + + +class AggNewAllZeros: + """Mixin to compute the new all zeros and share to the centers.""" + + @remote + @log_remote + def aggregate_new_all_zeros(self, shared_states: list) -> dict: + """Compute the global mean given the local results. + + Parameters + ---------- + shared_states : list + List of results (local_mean, n_samples) from training nodes. + In refit mode, also contains "loc_new_all_zero". + + Returns + ------- + dict + New all-zero genes. + """ + # Find genes that are all zero due to imputation of counts + new_all_zeroes = np.all( + [state["loc_new_all_zeroes"] for state in shared_states], axis=0 + ) + + return {"new_all_zeroes": new_all_zeroes} + + +class LocSetNewAllZerosAndGetFeatures: + """Mixin to set the new all zeros and return local features. + + This Mixin implements the method to perform the transition towards the + compute_rough_dispersions steps after refitting. 
It sets the new all zeros + genes in the local AnnData and computes the local features to be shared + to the aggregation node. + + Methods + ------- + local_set_new_all_zeros_get_features + The method to set the new all zeros genes and compute the local features. + + """ + + local_adata: ad.AnnData + refit_adata: ad.AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def local_set_new_all_zeros_get_features( + self, + data_from_opener, + shared_state, + ) -> dict: + """ + Set the new_all_zeros field and get the features. + + This method is used to set the new_all_zeros field in the local_adata uns + field. This is the set of genes that are all zero after outlier replacement. + + It then restricts the refit_adata to the genes which are not all_zero. + + Finally, it computes the local features to be shared via shared_state to the + aggregation node. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Shared state containing the "new_all_zeroes" key. + + Returns + ------- + dict + Local feature vector to be shared via shared_state to + the aggregation node. + """ + # Take all-zero genes into account + new_all_zeroes = shared_state["new_all_zeroes"] + + self.local_adata.uns["new_all_zeroes_genes"] = self.refit_adata.var_names[ + new_all_zeroes + ] + + self.local_adata.varm["refitted"] = self.local_adata.varm["replaced"].copy() + # Only replace if genes are not all zeroes after outlier replacement + self.local_adata.varm["refitted"][ + self.local_adata.varm["refitted"] + ] = ~new_all_zeroes + + # RESTRICT REFIT ADATA TO NOT NEW ALL ZEROES + self.refit_adata = self.refit_adata[:, ~new_all_zeroes].copy() + + # Update normed counts + set_normed_counts(self.refit_adata) + + #### ---- Compute Gram matrix and feature vector ---- #### + + design = self.refit_adata.obsm["design_matrix"].values + + return { + "local_features": design.T @ self.refit_adata.layers["normed_counts"], + } diff --git a/fedpydeseq2/core/deseq2_core/replace_refitted_values/__init__.py b/fedpydeseq2/core/deseq2_core/replace_refitted_values/__init__.py new file mode 100644 index 0000000..3108bf4 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/replace_refitted_values/__init__.py @@ -0,0 +1,3 @@ +from fedpydeseq2.core.deseq2_core.replace_refitted_values.replace_refitted_values import ( # noqa: E501 + ReplaceRefittedValues, +) diff --git a/fedpydeseq2/core/deseq2_core/replace_refitted_values/replace_refitted_values.py b/fedpydeseq2/core/deseq2_core/replace_refitted_values/replace_refitted_values.py new file mode 100644 index 0000000..a386149 --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/replace_refitted_values/replace_refitted_values.py @@ -0,0 +1,99 @@ +import anndata as ad +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data + + +class ReplaceRefittedValues: + """Mixin class to replace refitted values.""" + + local_adata: ad.AnnData | None + refit_adata: ad.AnnData | None + + def replace_refitted_values( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """Replace the values that were refitted in `local_adata`s. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: list[Dict] + Local states. 
Required to propagate intermediate results. + + round_idx: int + Index of the current round. + + clean_models: bool + Whether to clean the models after the computation. + + Returns + ------- + local_states: dict + Local states, with refitted values + + round_idx: int + The updated round index. + """ + local_states, shared_states, round_idx = local_step( + local_method=self.loc_replace_refitted_values, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Replace refitted values in local adatas", + round_idx=round_idx, + clean_models=clean_models, + ) + + return local_states, round_idx + + @remote_data + @log_remote_data + @reconstruct_adatas + def loc_replace_refitted_values(self, data_from_opener, shared_state): + """ + Replace refitted values in local_adata from refit_adata. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Not used. + """ + # Replace values in main object + list_varm_keys = [ + "_normed_means", + "LFC", + "genewise_dispersions", + "fitted_dispersions", + "MAP_dispersions", + "dispersions", + ] + for key in list_varm_keys: + self.local_adata.varm[key][ + self.local_adata.varm["refitted"] + ] = self.refit_adata.varm[key] + + # Take into account new all-zero genes + new_all_zeroes_genes = self.local_adata.uns["new_all_zeroes_genes"] + if len(new_all_zeroes_genes) > 0: + self.local_adata.varm["_normed_means"][ + self.local_adata.var_names.get_indexer(new_all_zeroes_genes) + ] = 0 + self.local_adata.varm["LFC"].loc[new_all_zeroes_genes, :] = 0 diff --git a/fedpydeseq2/core/deseq2_core/save_pipeline_results.py b/fedpydeseq2/core/deseq2_core/save_pipeline_results.py new file mode 100644 index 0000000..58bb3fd --- /dev/null +++ b/fedpydeseq2/core/deseq2_core/save_pipeline_results.py @@ -0,0 +1,155 @@ +"""Module to implement Mixin to get results as a shared state.""" + +import anndata as ad +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.pass_on_results import AggPassOnResults + + +class SavePipelineResults(AggPassOnResults): + """Mixin class to save pipeline results. + + Attributes + ---------- + local_adata : AnnData + Local AnnData object. + + results : dict + Results to share. + + VARM_KEYS : list + List of keys to extract from the varm attribute. + + UNS_KEYS : list + List of keys to extract from the uns attribute. + + Methods + ------- + save_pipeline_results + Save the pipeline results. + These results will be downloaded at the end of the pipeline. + They are defined using the VARM_KEYS and UNS_KEYS attributes. + + get_results_from_local_states + Get the results to share from the local states. 
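+
+    Notes
+    -----
+    The aggregated ``results`` dictionary maps ``"gene_names"``, every key in
+    ``VARM_KEYS`` (with ``None`` for keys that were never computed) and every
+    key in ``UNS_KEYS`` to the corresponding values. A hedged sketch of how the
+    downloaded results could be consumed, assuming ``padj`` is returned as a
+    one-dimensional array aligned with ``gene_names``::
+
+        import pandas as pd
+
+        padj = pd.Series(results["padj"], index=results["gene_names"])
+        significant_genes = padj[padj < 0.05].index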
+ + """ + + local_adata: ad.AnnData + results: dict | None + + VARM_KEYS = [ + "MAP_dispersions", + "dispersions", + "genewise_dispersions", + "non_zero", + "fitted_dispersions", + "LFC", + "padj", + "p_values", + "wald_statistics", + "wald_se", + "replaced", + "refitted", + ] + + UNS_KEYS = [ + "prior_disp_var", + "_squared_logres", + "contrast", + ] + + def save_pipeline_results( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """Build the results that will be downloaded at the end of the pipeline. + + Parameters + ---------- + train_data_nodes: list[TrainDataNode] + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: int + Index of the current round. + + clean_models: bool + Whether to clean the models after the computation. + + """ + local_states, shared_states, round_idx = local_step( + local_method=self.get_results_from_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Get results to share from the local centers", + round_idx=round_idx, + clean_models=clean_models, + ) + + # Build the global list of genes for which to replace outliers + aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Merge the lists of results and return output", + round_idx=round_idx, + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_results_from_local_states( + self, + data_from_opener, + shared_state: dict | None, + ) -> dict: + """ + Get the results to share from the local states. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict, optional + Not used. + + Returns + ------- + dict + Shared state containing the gene names, as well + as selected fields from the varm and uns attributes. + + """ + shared_state = { + "gene_names": self.local_adata.var_names, + } + for varm_key in self.VARM_KEYS: + if varm_key in self.local_adata.varm.keys(): + shared_state[varm_key] = self.local_adata.varm[varm_key] + else: + shared_state[varm_key] = None + + for uns_key in self.UNS_KEYS: + shared_state[uns_key] = self.local_adata.uns[uns_key] + + return shared_state diff --git a/fedpydeseq2/core/deseq2_strategy.py b/fedpydeseq2/core/deseq2_strategy.py new file mode 100644 index 0000000..c2eb07f --- /dev/null +++ b/fedpydeseq2/core/deseq2_strategy.py @@ -0,0 +1,381 @@ +import pickle as pkl +from pathlib import Path +from typing import Any +from typing import Literal + +import anndata as ad +from substrafl import ComputePlanBuilder +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef + +from fedpydeseq2.core.deseq2_core import DESeq2FullPipe +from fedpydeseq2.core.utils.logging import log_save_local_state + + +class DESeq2Strategy(ComputePlanBuilder, DESeq2FullPipe): + """DESeq2 strategy. + + This strategy is an implementation of the DESeq2 algorithm. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. 
+ If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + contrast : list or None + A list of three strings, in the following format: + ``['variable_of_interest', 'tested_level', 'ref_level']``. + Names must correspond to the metadata data passed to the DeseqDataSet. + E.g., ``['condition', 'B', 'A']`` will measure the LFC of 'condition B' compared + to 'condition A'. + For continuous variables, the last two strings should be left empty, e.g. + ``['measurement', '', ''].`` + If None, the last variable from the design matrix is chosen + as the variable of interest, and the reference level is picked alphabetically. + (default: ``None``). + + lfc_null : float + The (log2) log fold change under the null hypothesis. (default: ``0``). + + alt_hypothesis : str or None + The alternative hypothesis for computing wald p-values. By default, the normal + Wald test assesses deviation of the estimated log fold change from the null + hypothesis, as given by ``lfc_null``. + One of ``["greaterAbs", "lessAbs", "greater", "less"]`` or ``None``. + The alternative hypothesis corresponds to what the user wants to find rather + than the null hypothesis. (default: ``None``). + + min_replicates : int + Minimum number of replicates a condition should have + to allow refitting its samples. (default: ``7``). + + min_disp : float + Lower threshold for dispersion parameters. (default: ``1e-8``). + + max_disp : float + Upper threshold for dispersion parameters. + Note: The threshold that is actually enforced is max(max_disp, len(counts)). + (default: ``10``). + + grid_batch_size : int + The number of genes to put in each batch for local parallel processing. + (default: ``100``). + + grid_depth : int + The number of grid interval selections to perform (if using GridSearch). + (default: ``3``). + + grid_length : int + The number of grid points to use for the grid search (if using GridSearch). + (default: ``100``). + + num_jobs : int + The number of jobs to use for local parallel processing in MLE tasks. + (default: ``8``). + + independent_filter : bool + Whether to perform independent filtering to correct p-value trends. + (default: ``True``). + + alpha : float + P-value and adjusted p-value significance threshold (usually 0.05). + (default: ``0.05``). + + min_mu : float + The minimum value of the mean parameter mu. (default: ``0.5``). + + beta_tol : float + The tolerance for the beta parameter. (default: ``1e-8``). This is used + in the IRLS algorithm to stop the iterations when the relative change in + the beta parameter is smaller than beta_tol. + + max_beta : float + The maximum value for the beta parameter. (default: ``30``). + + irls_num_iter : int + The number of iterations to perform in the IRLS algorithm. (default: ``20``). + + joblib_backend : str + The backend to use for parallel processing. (default: ``loky``). + + joblib_verbosity : int + The verbosity level of joblib. (default: ``3``). 
+ + irls_batch_size : int + The number of genes to put in each batch for local parallel processing in the + IRLS algorithm. (default: ``100``). + + PQN_c1 : float + The Armijo line search constant for the prox newton. + + PQN_ftol : float + The functional stopping criterion for the prox newton method (relative error + smaller than ftol). + + PQN_num_iters_ls : int + The number of iterations performed in the line search at each prox newton step. + + PQN_num_iters : int, + The number of iterations in the prox newton catch of IRLS. + + PQN_min_mu : float + The minimum value for mu in the prox newton method. + + refit_cooks : bool + Whether to refit the model after computation of Cooks distance. + (default: ``True``). + + cooks_filter : bool + Whether to filter out genes with high Cooks distance in the pvalue computation. + (default: ``True``). + + save_layers_to_disk : bool + Whether to save the layers to disk. (default: ``False``). + If True, the layers will be saved to disk. + + trimmed_mean_num_iter: int + The number of iterations to use when computing the trimmed mean + in a federated way, i.e. the number of dichotomy steps. The default is + 40. + + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + contrast: list[str] | None = None, + lfc_null: float = 0.0, + alt_hypothesis: Literal["greaterAbs", "lessAbs", "greater", "less"] + | None = None, + min_replicates: int = 7, + min_disp: float = 1e-8, + max_disp: float = 10.0, + grid_batch_size: int = 250, + grid_depth: int = 3, + grid_length: int = 100, + num_jobs=8, + min_mu: float = 0.5, + beta_tol: float = 1e-8, + max_beta: float = 30, + irls_num_iter: int = 20, + joblib_backend: str = "loky", + joblib_verbosity: int = 0, + irls_batch_size: int = 100, + independent_filter: bool = True, + alpha: float = 0.05, + PQN_c1: float = 1e-4, + PQN_ftol: float = 1e-7, + PQN_num_iters_ls: int = 20, + PQN_num_iters: int = 100, + PQN_min_mu: float = 0.0, + refit_cooks: bool = True, + cooks_filter: bool = True, + save_layers_to_disk: bool = False, + trimmed_mean_num_iter: int = 40, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. 
+ super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + contrast=contrast, + lfc_null=lfc_null, + alt_hypothesis=alt_hypothesis, + min_replicates=min_replicates, + min_disp=min_disp, + max_disp=max_disp, + grid_batch_size=grid_batch_size, + grid_depth=grid_depth, + grid_length=grid_length, + num_jobs=num_jobs, + min_mu=min_mu, + beta_tol=beta_tol, + max_beta=max_beta, + irls_num_iter=irls_num_iter, + joblib_backend=joblib_backend, + joblib_verbosity=joblib_verbosity, + irls_batch_size=irls_batch_size, + independent_filter=independent_filter, + alpha=alpha, + PQN_c1=PQN_c1, + PQN_ftol=PQN_ftol, + PQN_num_iters_ls=PQN_num_iters_ls, + PQN_num_iters=PQN_num_iters, + PQN_min_mu=PQN_min_mu, + refit_cooks=refit_cooks, + cooks_filter=cooks_filter, + trimmed_mean_num_iter=trimmed_mean_num_iter, + ) + + #### Define hyper parameters #### + + self.min_disp = min_disp + self.max_disp = max_disp + self.grid_batch_size = grid_batch_size + self.grid_depth = grid_depth + self.grid_length = grid_length + self.min_mu = min_mu + self.beta_tol = beta_tol + self.max_beta = max_beta + + # Parameters of the IRLS algorithm + self.irls_num_iter = irls_num_iter + self.min_replicates = min_replicates + self.PQN_c1 = PQN_c1 + self.PQN_ftol = PQN_ftol + self.PQN_num_iters_ls = PQN_num_iters_ls + self.PQN_num_iters = PQN_num_iters + self.PQN_min_mu = PQN_min_mu + + # Parameters for the trimmed mean computation + self.trimmed_mean_num_iter = trimmed_mean_num_iter + + #### Stat parameters + self.independent_filter = independent_filter + self.alpha = alpha + + #### Define job parallelization parameters #### + + self.num_jobs = num_jobs + self.joblib_verbosity = joblib_verbosity + self.joblib_backend = joblib_backend + self.irls_batch_size = irls_batch_size + + #### Define quantities to set the design #### + + # Convert design_factors to list if a single string was provided. + self.design_factors = ( + [design_factors] if isinstance(design_factors, str) else design_factors + ) + + self.ref_levels = ref_levels + self.continuous_factors = continuous_factors + + if self.continuous_factors is not None: + self.categorical_factors = [ + factor + for factor in self.design_factors + if factor not in self.continuous_factors + ] + else: + self.categorical_factors = self.design_factors + + self.contrast = contrast + + #### Set test parameters #### + self.lfc_null = lfc_null + self.alt_hypothesis = alt_hypothesis + + #### If we want to refit cooks outliers + self.refit_cooks = refit_cooks + + #### Define quantities to compute statistics + self.cooks_filter = cooks_filter + + #### Set attributes to be registered / saved later on #### + self.local_adata: ad.AnnData | None = None + self.refit_adata: ad.AnnData | None = None + self.results: dict | None = None + + #### Save layers to disk + self.save_layers_to_disk = save_layers_to_disk + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a FedDESeq2 pipe. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``True``). 
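+
+        Notes
+        -----
+        This method is meant to be called by substrafl when the experiment is
+        executed, not directly by the user. A hedged sketch of how the strategy
+        itself is typically instantiated (the argument values are illustrative)::
+
+            strategy = DESeq2Strategy(
+                design_factors="condition",
+                ref_levels={"condition": "A"},
+                contrast=["condition", "B", "A"],
+                refit_cooks=True,
+            )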
+ """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + self.run_deseq_pipe( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + @log_save_local_state + def save_local_state(self, path: Path) -> None: + """Save the local state of the strategy. + + Parameters + ---------- + path : Path + Path to the file where to save the state. Automatically handled by subtrafl. + """ + state_to_save = { + "local_adata": self.local_adata, + "refit_adata": self.refit_adata, + "results": self.results, + } + with open(path, "wb") as file: + pkl.dump(state_to_save, file) + + def load_local_state(self, path: Path) -> Any: + """Load the local state of the strategy. + + Parameters + ---------- + path : Path + Path to the file where to load the state from. Automatically handled by + subtrafl. + """ + with open(path, "rb") as file: + state_to_load = pkl.load(file) + + self.local_adata = state_to_load["local_adata"] + self.refit_adata = state_to_load["refit_adata"] + self.results = state_to_load["results"] + + return self + + @property + def num_round(self): + """Return the number of round in the strategy. + + TODO do something clever with this. + + Returns + ------- + int + Number of round in the strategy. + """ + return None diff --git a/fedpydeseq2/core/fed_algorithms/__init__.py b/fedpydeseq2/core/fed_algorithms/__init__.py new file mode 100644 index 0000000..3eaf4ef --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/__init__.py @@ -0,0 +1,17 @@ +"""Module containing the methods that are used multiple times in the pipeline. + +These methods have been adapted from a pooled setting to a federated setting. +They are: +- the computation of the trimmed mean +- the federated IRLS computation with a negative binomial distribution +- the federated Proximal Quasi Newton computation with a negative binomial distribution +- the federated grid search computation with a negative binomial distribution for the + alpha parameter. 
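+
+They are exposed below as the ``ComputeTrimmedMean``, ``ComputeDispersionsGridSearch``,
+``FedIRLS`` and ``FedProxQuasiNewton`` Mixin classes.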
+""" + +from fedpydeseq2.core.fed_algorithms.compute_trimmed_mean import ComputeTrimmedMean +from fedpydeseq2.core.fed_algorithms.dispersions_grid_search import ( + ComputeDispersionsGridSearch, +) +from fedpydeseq2.core.fed_algorithms.fed_irls import FedIRLS +from fedpydeseq2.core.fed_algorithms.fed_PQN import FedProxQuasiNewton diff --git a/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/__init__.py b/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/__init__.py new file mode 100644 index 0000000..bc6c328 --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/__init__.py @@ -0,0 +1,3 @@ +from fedpydeseq2.core.fed_algorithms.compute_trimmed_mean.compute_trimmed_mean import ( + ComputeTrimmedMean, +) diff --git a/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/compute_trimmed_mean.py b/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/compute_trimmed_mean.py new file mode 100644 index 0000000..dacb093 --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/compute_trimmed_mean.py @@ -0,0 +1,194 @@ +"""Module containing the steps to compute trimmed mean.""" +from typing import Literal + +from fedpydeseq2.core.fed_algorithms.compute_trimmed_mean.substeps import ( + AggFinalTrimmedMean, +) +from fedpydeseq2.core.fed_algorithms.compute_trimmed_mean.substeps import ( + AggInitTrimmedMean, +) +from fedpydeseq2.core.fed_algorithms.compute_trimmed_mean.substeps import ( + AggIterationTrimmedMean, +) +from fedpydeseq2.core.fed_algorithms.compute_trimmed_mean.substeps import ( + LocalIterationTrimmedMean, +) +from fedpydeseq2.core.fed_algorithms.compute_trimmed_mean.substeps import ( + LocFinalTrimmedMean, +) +from fedpydeseq2.core.fed_algorithms.compute_trimmed_mean.substeps import ( + LocInitTrimmedMean, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class ComputeTrimmedMean( + LocInitTrimmedMean, + AggInitTrimmedMean, + LocalIterationTrimmedMean, + AggIterationTrimmedMean, + LocFinalTrimmedMean, + AggFinalTrimmedMean, +): + """Strategy to compute the trimmed mean.""" + + def compute_trim_mean( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx: int, + clean_models: bool, + layer_used: str, + mode: Literal["normal", "cooks"] = "normal", + trim_ratio: float | None = None, + n_iter: int = 50, + refit: bool = False, + min_replicates_trimmed_mean: int = 3, + ): + """ + Run the trimmed mean computation on the layer specified. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + layer_used : str + The layer used to compute the trimmed mean. + + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: int + The current round. + + clean_models: bool + If True, the models are cleaned. + + mode : Literal["normal", "cooks"] + The mode to use. If "cooks", the local trimmed mean is actually computed + per level, and predefined trim ratios are applied, as well as certain + scaling factors on the outputed means. + If "normal", the local trimmed mean is computed on the whole dataset, as + expected, using the trim_ratio parameter. + + trim_ratio : float or None + The ratio to trim. Should be between 0 and 0.5. + Is only used in "normal" mode, and should be None in "cooks" mode. + + n_iter : int + The number of iterations. + + refit : bool + If True, the function will compute the trimmed mean on the refit adata only. 
+ + min_replicates_trimmed_mean : int + The minimum number of replicates to compute the trimmed mean. + + Returns + ------- + local_states: list[dict] + Local states dictionaries. + + final_trimmed_mean_agg_share_state: dict + Dictionary containing the final trimmed mean aggregation share + state in a field "trimmed_mean_". + + round_idx: int + + """ + if mode == "cooks": + assert trim_ratio is None, "trim_ratio should be None in cooks mode" + + local_states, shared_states, round_idx = local_step( + local_method=self.loc_init_trimmed_mean, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize trim mean", + round_idx=round_idx, + clean_models=clean_models, + method_params={ + "layer_used": layer_used, + "mode": mode, + "refit": refit, + "min_replicates_trimmed_mean": min_replicates_trimmed_mean, + }, + ) + + aggregation_share_state, round_idx = aggregation_step( + aggregation_method=self.agg_init_trimmed_mean, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Aggregation init of trimmed mean", + round_idx=round_idx, + clean_models=clean_models, + ) + + for _ in range(n_iter): + local_states, shared_states, round_idx = local_step( + local_method=self.local_iteration_trimmed_mean, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=aggregation_share_state, + aggregation_id=aggregation_node.organization_id, + description="Local iteration of trimmed mean", + round_idx=round_idx, + clean_models=clean_models, + method_params={ + "layer_used": layer_used, + "mode": mode, + "trim_ratio": trim_ratio, + "refit": refit, + }, + ) + + aggregation_share_state, round_idx = aggregation_step( + aggregation_method=self.agg_iteration_trimmed_mean, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Aggregation iteration of trimmed mean", + round_idx=round_idx, + clean_models=clean_models, + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.final_local_trimmed_mean, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=aggregation_share_state, + aggregation_id=aggregation_node.organization_id, + description="Final local step of trimmed mean", + round_idx=round_idx, + clean_models=clean_models, + method_params={ + "layer_used": layer_used, + "trim_ratio": trim_ratio, + "mode": mode, + "refit": refit, + }, + ) + + final_trimmed_mean_agg_share_state, round_idx = aggregation_step( + aggregation_method=self.final_agg_trimmed_mean, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="final aggregation of trimmed mean", + round_idx=round_idx, + clean_models=clean_models, + method_params={"layer_used": layer_used, "mode": mode}, + ) + + return local_states, final_trimmed_mean_agg_share_state, round_idx diff --git a/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/substeps.py b/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/substeps.py new file mode 100644 index 0000000..852fde0 --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/substeps.py @@ -0,0 +1,869 @@ +"""Module to implement the substeps for comuting the trimmed mean. 
+ +This module contains all these substeps as mixin classes. +""" +from typing import Literal + +import numpy as np +import pandas as pd +from anndata import AnnData +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.fed_algorithms.compute_trimmed_mean.utils import get_scale +from fedpydeseq2.core.fed_algorithms.compute_trimmed_mean.utils import get_trim_ratio +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocInitTrimmedMean: + """Mixin class to implement the local initialisation of the trimmed mean algo.""" + + local_adata: AnnData + refit_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def loc_init_trimmed_mean( + self, + data_from_opener, + shared_state, + layer_used: str, + mode: Literal["normal", "cooks"] = "normal", + refit: bool = False, + min_replicates_trimmed_mean: int = 3, + ) -> dict: + """ + Initialise the trimmed mean algo, by providing the lower and max bounds. + + Parameters + ---------- + data_from_opener : AnnData + Unused, all the necessary info is stored in the local adata. + + shared_state : dict + Not used, all the necessary info is stored in the local adata. + + layer_used : str + Name of the layer used to compute the trimmed mean. + + mode : Literal["normal", "cooks"] + Mode of the trimmed mean algo. If "cooks", the function will be applied + either on the normalized counts or the squared error. + It will be applied per level, except if there are not enough samples. + + refit : bool + If true, the function will use the refit adata to compute the trimmed mean. + + min_replicates_trimmed_mean : int + Minimum number of replicates to compute the trimmed mean. 
+ + Returns + ------- + dict + If mode is "normal" or if mode is "cooks" and there are not enough samples, + to compute the trimmed mean per level, a dictionary with the following keys + - max_values: np.ndarray of size (n_genes,) + - min_values: np.ndarray of size (n_genes,) + - use_lvl: False + otherwise, a dictionary with the max_values and min_values keys, nested + inside a dictionary with the levels as keys, plus a use_lvl with value True + """ + if refit: + adata = self.refit_adata + else: + adata = self.local_adata + + if mode == "cooks": + # Check that the layer is either cooks or normed counts + assert layer_used in ["sqerror", "normed_counts"] + # Check that num replicates is in the uns + assert "num_replicates" in adata.uns, "No num_replicates in the adata" + use_lvl = adata.uns["num_replicates"].max() >= min_replicates_trimmed_mean + assert "cells" in adata.obs, "No cells column in the adata" + + else: + use_lvl = False + result = {"use_lvl": use_lvl} + if use_lvl: + # In that case, we know we are in cooks mode + admissible_levels = adata.uns["num_replicates"][ + adata.uns["num_replicates"] >= min_replicates_trimmed_mean + ].index + + shared_state = {lvl: shared_state for lvl in admissible_levels} + for lvl in admissible_levels: + mask = adata.obs["cells"] == lvl + result[lvl] = self.loc_init_trimmed_mean_per_lvl( # type: ignore + data_from_opener, shared_state[lvl], layer_used, mask, refit + ) + return result + else: + result.update( + self.loc_init_trimmed_mean_per_lvl( + data_from_opener, + shared_state, + layer_used, + mask=np.ones(adata.n_obs, dtype=bool), + refit=refit, + ) + ) + return result + + def loc_init_trimmed_mean_per_lvl( + self, + data_from_opener, + shared_state, + layer_used: str, + mask, + refit: bool = False, + ) -> dict: + """ + Initialise the trimmed mean algo, by providing the lower and max bounds. + + Parameters + ---------- + data_from_opener : AnnData + Unused, all the necessary info is stored in the local adata. + + shared_state : dict + Not used, all the necessary info is stored in the local adata. + + layer_used : str + Name of the layer used to compute the trimmed mean. + + mask : np.ndarray + Mask to filter values used in the min and max computation. + + refit : bool + If true, the function will use the refit adata to compute the trimmed mean. + + Returns + ------- + dict + Dictionary with the following keys + - max_values: np.ndarray of size (n_genes,) + - min_values: np.ndarray of size (n_genes,) + - n_samples: int, number of samples + """ + if refit: + adata = self.refit_adata + else: + adata = self.local_adata + + assert layer_used in adata.layers + local_adata_filtered = adata[mask] + if local_adata_filtered.n_obs > 0: + max_values = local_adata_filtered.layers[layer_used].max(axis=0) + min_values = local_adata_filtered.layers[layer_used].min(axis=0) + else: + max_values = np.zeros(adata.n_vars) * np.nan + min_values = np.zeros(adata.n_vars) * np.nan + return { + "max_values": max_values, + "min_values": min_values, + } + + +class AggInitTrimmedMean: + """Mixin class for the aggregation of the init of the trimmed mean algo.""" + + @remote + @log_remote + def agg_init_trimmed_mean( + self, + shared_states: list[dict], + ) -> dict: + """ + Compute the initial global upper and lower bounds. 
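+
+        The global bounds are the element-wise minimum and maximum of the local
+        bounds sent by each center. They initialise the dichotomic search of the
+        trim-quantile thresholds performed in the following iterations.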
+ + Parameters + ---------- + shared_states : list[dict] + If use_lvl is False (in any shared state), + list of dictionaries with the following keys: + - max_values: np.ndarray of size (n_genes,) + - min_values: np.ndarray of size (n_genes,) + If use_lvl is True, list of dictionaries with the same keys as above + nested inside a dictionary with the levels as keys. + + Returns + ------- + dict + use_level is a key present in all input shared states, and will be passed + on to the output shared state + If use_lvl is False, dict with the following keys: + - upper_bounds_thresholds : np.ndarray of size (n_genes, 2) + - lower_bounds_thresholds : np.ndarray of size (n_genes, 2) + otherwise, a dictionary with the same keys for nested inside a dictionary + with the levels as keys. + """ + use_lvl = shared_states[0]["use_lvl"] + result = {"use_lvl": use_lvl} + if use_lvl: + for lvl in shared_states[0].keys(): + if lvl == "use_lvl": + continue + result[lvl] = self.agg_init_trimmed_mean_per_lvl( + [state[lvl] for state in shared_states] + ) + return result + else: + result.update(self.agg_init_trimmed_mean_per_lvl(shared_states)) + return result + + def agg_init_trimmed_mean_per_lvl(self, shared_states: list[dict]) -> dict: + """ + Compute the initial global upper and lower bounds. + + Parameters + ---------- + shared_states : list[dict] + List of dictionaries with the following keys: + - max_values: np.ndarray of size (n_genes,) + - min_values: np.ndarray of size (n_genes,) + + Returns + ------- + dict + dict with the following keys: + - upper_bounds_thresholds : np.ndarray of size (n_genes, 2) + - lower_bounds_thresholds : np.ndarray of size (n_genes, 2) + """ + # To initialize the dichotomic search of the quantile thresholds, we need to + # set the upper and lower bounds of the thresholds. + upper_bounds_thresholds = np.nanmax( + np.array([state["max_values"] for state in shared_states]), axis=0 + ) + lower_bounds_thresholds = np.nanmin( + np.array([state["min_values"] for state in shared_states]), axis=0 + ) + + # We are looking for two thresholds, one for the upper quantile and one for the + # lower quantile. We initialize the search with the same value for both. + upper_bounds_thresholds = np.vstack([upper_bounds_thresholds] * 2).T + lower_bounds_thresholds = np.vstack([lower_bounds_thresholds] * 2).T + + upper_bounds_thresholds = upper_bounds_thresholds.astype(np.float32) + lower_bounds_thresholds = lower_bounds_thresholds.astype(np.float32) + + return { + "upper_bounds_thresholds": upper_bounds_thresholds, + "lower_bounds_thresholds": lower_bounds_thresholds, + } + + +class LocalIterationTrimmedMean: + """Mixin class to implement the local iteration of the trimmed mean algo.""" + + local_adata: AnnData + refit_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def local_iteration_trimmed_mean( + self, + data_from_opener, + shared_state, + layer_used: str, + mode: Literal["normal", "cooks"] = "normal", + trim_ratio: float | None = None, + refit: bool = False, + ) -> dict: + """ + Local iteration of the trimmed mean algo. + + Parameters + ---------- + data_from_opener : AnnData + Not used, all the necessary info is stored in the local adata. + + shared_state : dict + Dictionary with the following keys: + - upper_bounds_thresholds : np.ndarray of size (n_genes,2). Not used. + - lower_bounds_thresholds : np.ndarray of size (n_genes,2). Not used. + If use_lvl is true, the dictionary is nested with the levels as keys. 
+ + layer_used : str + Name of the layer used to compute the trimmed mean. + + mode : Literal["normal", "cooks"] + Mode of the trimmed mean algo. If "cooks", the function will be applied + either on the normalized counts or the squared error. + It will be applied per level, except if there are not enough samples. + Moreover, trim ratios will be computed based on the number of replicates. + If "normal", the function will be applied on the whole dataset, using the + trim_ratio parameter. + + trim_ratio : float or None + Ratio of the samples to be trimmed. Must be between 0 and 0.5. Must be + None if mode is "cooks", and float if mode is "normal". + + refit : bool + If true, the function will use the refit adata to compute the trimmed mean. + + Returns + ------- + dict + Dictionary containing the following keys: + - num_strictly_above: np.ndarray[int] of size (n_genes,2) + - upper_bounds_thresholds: np.ndarray of size (n_genes,2) + - lower_bounds_thresholds: np.ndarray of size (n_genes,2) + - n_samples: int + - trim_ratio: float + If use_lvl is true, the dictionary is nested with the levels as keys. + + """ + if refit: + adata = self.refit_adata + else: + adata = self.local_adata + use_lvl = shared_state["use_lvl"] + result = {"use_lvl": use_lvl} + + if mode == "cooks": + assert trim_ratio is None + else: + assert trim_ratio is not None + + if mode == "cooks" and use_lvl: + for lvl in shared_state.keys(): + if lvl == "use_lvl": + continue + mask = adata.obs["cells"] == lvl + result[lvl] = self.local_iteration_trimmed_mean_per_lvl( + data_from_opener, shared_state[lvl], layer_used, mask, refit + ) + trim_ratio = get_trim_ratio(adata.uns["num_replicates"][lvl]) + result[lvl]["trim_ratio"] = trim_ratio + return result + + result.update( + self.local_iteration_trimmed_mean_per_lvl( + data_from_opener, + shared_state, + layer_used, + mask=np.ones(adata.n_obs, dtype=bool), + refit=refit, + ) + ) + if mode == "cooks": + result["trim_ratio"] = 0.125 + else: + result["trim_ratio"] = trim_ratio + + return result + + def local_iteration_trimmed_mean_per_lvl( + self, data_from_opener, shared_state, layer_used, mask, refit: bool = False + ) -> dict: + """ + Local iteration of the trimmed mean algo. + + Parameters + ---------- + data_from_opener : AnnData + Not used, all the necessary info is stored in the local adata. + + shared_state : dict + Dictionary with the following keys: + - upper_bounds_thresholds : np.ndarray of size (n_genes,2). Not used. + - lower_bounds_thresholds : np.ndarray of size (n_genes,2). Not used. + + layer_used : str + Name of the layer used to compute the trimmed mean. + + mask : np.ndarray + Mask to filter values used in the quantile computation. + + refit : bool + If true, the function will use the refit adata to compute the trimmed mean. + + Returns + ------- + dict + Dictionary containing the following keys: + - num_strictly_above: np.ndarray[int] of size (n_genes,2) + - upper_bounds_thresholds: np.ndarray of size (n_genes,2) + - lower_bounds_thresholds: np.ndarray of size (n_genes,2) + - n_samples: int + + """ + if refit: + adata = self.refit_adata + else: + adata = self.local_adata + + # We don't need to pass the thresholds in the share states as it's always the + # mean of the upper and lower bounds. + threshold = ( + shared_state["upper_bounds_thresholds"] + + shared_state["lower_bounds_thresholds"] + ) / 2 + local_adata_filtered = adata[mask] + # Array of size (n_genes, 2) containing the number of samples above the + # thresholds. 
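+        # Broadcasting note: the layer has shape (n_obs, n_genes); adding a
+        # trailing axis and comparing against the (n_genes, 2) threshold array
+        # gives a boolean array of shape (n_obs, n_genes, 2), which is summed
+        # over the sample axis to count, per gene, how many local samples lie
+        # strictly above each of the two candidate thresholds.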
+        num_strictly_above = (
+            local_adata_filtered.layers[layer_used][..., None] > threshold[None, ...]
+        ).sum(axis=0)
+
+        return {
+            "num_strictly_above": num_strictly_above,
+            "upper_bounds_thresholds": shared_state["upper_bounds_thresholds"],
+            "lower_bounds_thresholds": shared_state["lower_bounds_thresholds"],
+            "n_samples": local_adata_filtered.n_obs,
+        }
+
+
+class AggIterationTrimmedMean:
+    """Mixin class of the aggregation of the iteration of the trimmed mean algo."""
+
+    @remote
+    @log_remote
+    def agg_iteration_trimmed_mean(
+        self,
+        shared_states: list[dict],
+    ) -> dict:
+        """
+        Aggregate the local counts and update the global upper and lower bounds.
+
+        Parameters
+        ----------
+        shared_states : list[dict]
+            List of dictionaries with the following keys:
+            - num_strictly_above: np.ndarray[int] of size (n_genes,2)
+            - upper_bounds_thresholds: np.ndarray of size (n_genes,2)
+            - lower_bounds_thresholds: np.ndarray of size (n_genes,2)
+            - n_samples: int
+            - trim_ratio: float
+            If use_lvl is true, the dictionary is nested with the levels as keys.
+
+        Returns
+        -------
+        dict
+            Dictionary with the following keys:
+            - upper_bounds_thresholds : np.ndarray of size (n_genes, 2)
+            - lower_bounds_thresholds : np.ndarray of size (n_genes, 2)
+            If use_lvl is true, the dictionary is nested with the levels as keys.
+        """
+        use_lvl = shared_states[0]["use_lvl"]
+        result = {"use_lvl": use_lvl}
+        if use_lvl:
+            for lvl in shared_states[0].keys():
+                if lvl == "use_lvl":
+                    continue
+                result[lvl] = self.agg_iteration_trimmed_mean_per_lvl(
+                    [state[lvl] for state in shared_states]
+                )
+            return result
+        else:
+            result.update(self.agg_iteration_trimmed_mean_per_lvl(shared_states))
+            return result
+
+    def agg_iteration_trimmed_mean_per_lvl(
+        self,
+        shared_states: list[dict],
+    ) -> dict:
+        """
+        Aggregate step of the iteration of the trimmed mean algo.
+
+        Parameters
+        ----------
+        shared_states : list[dict]
+            List of dictionaries containing the following keys:
+            - num_strictly_above: np.ndarray[int] of size (n_genes,2)
+            - upper_bounds_thresholds: np.ndarray of size (n_genes,2)
+            - lower_bounds_thresholds: np.ndarray of size (n_genes,2)
+            - n_samples: int
+            - trim_ratio: float
+
+        Returns
+        -------
+        dict
+            Dictionary with the following keys:
+            - upper_bounds_thresholds : np.ndarray of size (n_genes,2)
+            - lower_bounds_thresholds : np.ndarray of size (n_genes,2)
+        """
+        trim_ratio = shared_states[0]["trim_ratio"]
+        upper_bounds_thresholds = shared_states[0]["upper_bounds_thresholds"]
+        lower_bounds_thresholds = shared_states[0]["lower_bounds_thresholds"]
+
+        n_samples = np.sum([state["n_samples"] for state in shared_states])
+
+        n_trim = np.floor(n_samples * trim_ratio)
+        # Targets contain the number of samples we want to have above the two
+        # thresholds.
+
+        targets = np.array([n_trim, n_samples - n_trim])
+
+        # We sum the number of samples above the thresholds for each gene.
+        agg_n_samples_strictly_above_quantiles = np.sum(
+            [state["num_strictly_above"] for state in shared_states],
+            axis=0,
+        )
+
+        # Mask of size (n_genes,2) indicating for each gene and each of the two
+        # thresholds if the number of samples above the threshold is too high.
+        mask_threshold_too_high = (
+            agg_n_samples_strictly_above_quantiles < targets[None, :]
+        )
+
+        # Similarly, we create a mask for the case where the number of samples
+        # above the thresholds is too low.
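+        # Together with mask_threshold_too_high, these two masks drive one
+        # bisection step per gene and per threshold: whichever bound disagrees
+        # with the target count is moved to the midpoint, so each round roughly
+        # halves the per-gene search interval.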
+        mask_threshold_too_low = (
+            agg_n_samples_strictly_above_quantiles > targets[None, :]
+        )
+
+        # Update the bounds where the thresholds are too high or too low.
+        upper_bounds_thresholds[mask_threshold_too_high] = (
+            upper_bounds_thresholds[mask_threshold_too_high]
+            + lower_bounds_thresholds[mask_threshold_too_high]
+        ) / 2.0
+
+        lower_bounds_thresholds[mask_threshold_too_low] = (
+            upper_bounds_thresholds[mask_threshold_too_low]
+            + lower_bounds_thresholds[mask_threshold_too_low]
+        ) / 2.0
+
+        return {
+            "upper_bounds_thresholds": upper_bounds_thresholds,
+            "lower_bounds_thresholds": lower_bounds_thresholds,
+        }
+
+
+class LocFinalTrimmedMean:
+    """Mixin class to implement the local finalisation of the trimmed mean algo."""
+
+    local_adata: AnnData
+    refit_adata: AnnData
+
+    @remote_data
+    @log_remote_data
+    @reconstruct_adatas
+    def final_local_trimmed_mean(
+        self,
+        data_from_opener,
+        shared_state,
+        layer_used: str,
+        mode: Literal["normal", "cooks"] = "normal",
+        trim_ratio: float | None = None,
+        refit: bool = False,
+    ) -> dict:
+        """
+        Finalise the trimmed mean algo by computing the trimmed mean.
+
+        Parameters
+        ----------
+        data_from_opener : ad.AnnData
+            Unused, all the necessary info is stored in the local adata.
+
+        shared_state : dict
+            Dictionary with the following keys:
+            - upper_bounds_thresholds : np.ndarray of size (n_genes,2). Not used
+            - lower_bounds_thresholds : np.ndarray of size (n_genes,2). Not used
+            If use_lvl is true, the dictionary is nested with the levels as keys.
+
+        layer_used : str
+            Name of the layer used to compute the trimmed mean.
+
+        mode : Literal["normal", "cooks"]
+            Mode of the trimmed mean algo. If "cooks", the function will be applied
+            either on the normalized counts or the squared error.
+            It will be applied per level, except if there are not enough samples.
+            Moreover, trim ratios will be computed based on the number of replicates.
+            If "normal", the function will be applied on the whole dataset, using the
+            trim_ratio parameter.
+
+        trim_ratio : float or None
+            Ratio of the samples to be trimmed. Must be between 0 and 0.5. Must be
+            None if mode is "cooks", and float if mode is "normal".
+
+        refit : bool
+            If true, the function will use the refit adata to compute the trimmed mean.
+
+        Returns
+        -------
+        dict
+            Dictionary with the following keys:
+            - trimmed_local_sum : np.ndarray(float) of size (n_genes,)
+            - n_samples : int
+            - num_strictly_above : np.ndarray(int) of size (n_genes,2)
+            - upper_bounds_thresholds : np.ndarray of size (n_genes,2)
+            - lower_bounds_thresholds : np.ndarray of size (n_genes,2)
+            If use_lvl is true, the dictionary is nested with the levels as keys.
+ + """ + if refit: + adata = self.refit_adata + else: + adata = self.local_adata + + use_lvl = shared_state["use_lvl"] + result = {"use_lvl": use_lvl} + if mode == "cooks" and use_lvl: + for lvl in shared_state.keys(): + if lvl == "use_lvl": + continue + mask = adata.obs["cells"] == lvl + result[lvl] = self.final_local_trimmed_mean_per_lvl( + data_from_opener, shared_state[lvl], layer_used, mask, refit + ) + trim_ratio = get_trim_ratio(adata.uns["num_replicates"][lvl]) + if layer_used == "sqerror": + scale = get_scale(adata.uns["num_replicates"][lvl]) + result[lvl]["scale"] = scale + result[lvl]["trim_ratio"] = trim_ratio + return result + result.update( + self.final_local_trimmed_mean_per_lvl( + data_from_opener, + shared_state, + layer_used, + mask=np.ones(adata.n_obs, dtype=bool), + refit=refit, + ) + ) + if mode == "cooks": + assert trim_ratio is None + result["trim_ratio"] = 0.125 + if layer_used == "sqerror": + result["scale"] = 1.51 + else: + assert trim_ratio is not None + result["trim_ratio"] = trim_ratio + return result + + def final_local_trimmed_mean_per_lvl( + self, + data_from_opener, + shared_state, + layer_used, + mask, + refit: bool = False, + ) -> dict: + """ + Finalise the trimmed mean algo by computing the trimmed mean. + + Parameters + ---------- + data_from_opener : ad.AnnData + Unused, all the necessary info is stored in the local adata. + + shared_state : dict + Dictionary with the following keys: + - upper_bounds_thresholds : np.ndarray of size (n_genes,2). Not used + - lower_bounds_thresholds : np.ndarray of size (n_genes,2). Not used + + layer_used : str + Name of the layer used to compute the trimmed mean. + + mask : np.ndarray + Mask to filter values used in the quantile computation. + + refit : bool + If true, the function will use the refit adata to compute the trimmed mean. + + Returns + ------- + dict + Dictionary with the following keys: + - trimmed_local_sum : np.ndarray(float) of size (n_genes,2) + - n_samples : np.ndarray(int) of size (n_genes,2) + - num_strictly_above : np.ndarray(int) of size (n_genes,2) + - upper_bounds_thresholds : np.ndarray of size (n_genes,2) + - lower_bounds_thresholds : np.ndarray of size (n_genes,2) + + """ + if refit: + adata = self.refit_adata + else: + adata = self.local_adata + + # we create an explicit copy to avoid ImplicitModificationWarning + local_adata_filtered = adata[mask].copy() + current_thresholds = ( + shared_state["upper_bounds_thresholds"] + + shared_state["lower_bounds_thresholds"] + ) / 2.0 + + num_strictly_above = ( + local_adata_filtered.layers[layer_used][..., None] + > current_thresholds[None, ...] 
+        ).sum(axis=0)
+
+        mask_upper_threshold = (
+            local_adata_filtered.layers[layer_used]
+            > current_thresholds[..., 0][None, :]
+        )
+        mask_lower_threshold = (
+            local_adata_filtered.layers[layer_used]
+            <= current_thresholds[..., 1][None, :]
+        )
+        local_adata_filtered.layers[
+            f"trimmed_{layer_used}"
+        ] = local_adata_filtered.layers[layer_used].copy()
+        local_adata_filtered.layers[f"trimmed_{layer_used}"][
+            mask_upper_threshold | mask_lower_threshold
+        ] = 0
+
+        return {
+            "trimmed_local_sum": local_adata_filtered.layers[
+                f"trimmed_{layer_used}"
+            ].sum(axis=0),
+            "n_samples": local_adata_filtered.n_obs,
+            "num_strictly_above": num_strictly_above,
+            "upper_bounds_thresholds": shared_state["upper_bounds_thresholds"],
+            "lower_bounds_thresholds": shared_state["lower_bounds_thresholds"],
+        }
+
+
+class AggFinalTrimmedMean:
+    """Mixin class of the aggregation of the finalisation of the trimmed mean algo."""
+
+    @remote
+    @log_remote
+    def final_agg_trimmed_mean(
+        self,
+        shared_states: list[dict],
+        layer_used: str,
+        mode: Literal["normal", "cooks"] = "normal",
+    ) -> dict:
+        """
+        Compute the final trimmed means from the aggregated local sums.
+
+        Parameters
+        ----------
+        shared_states : list[dict]
+            List of dictionaries with the following keys:
+            - trimmed_local_sum : np.ndarray(float) of size (n_genes,)
+            - n_samples : int
+            - num_strictly_above : np.ndarray(int) of size (n_genes,2)
+            - upper_bounds_thresholds : np.ndarray of size (n_genes,2)
+            - lower_bounds_thresholds : np.ndarray of size (n_genes,2)
+            If use_lvl is true, the dictionary is nested with the levels as keys.
+
+        layer_used : str
+            Name of the layer used to compute the trimmed mean.
+
+        mode : Literal["normal", "cooks"]
+            Mode of the trimmed mean algo. If "cooks", the function will be applied
+            either on the normalized counts or the squared error.
+            It will be applied per level, except if there are not enough samples.
+            Moreover, trim ratios will be computed based on the number of replicates.
+            If "normal", the function will be applied on the whole dataset, using the
+            trim_ratio parameter.
+
+        Returns
+        -------
+        dict
+            If mode is "cooks" and if the layer is "sqerror", a dictionary with the
+            "varEst" key containing
+            - The maximum of the trimmed means per level if use_lvl is true,
+              rescaled by a scale factor depending on the number of replicates
+            - The trimmed mean of the whole dataset otherwise,
+              scaled by 1.51.
+ else, if mode is cooks and use_lvl is true, a dictionary with a + trimmed_mean_normed_counts key containing a dataframe + with the trimmed means per level, levels being columns + else, a dictionary with the following keys: + - trimmed_mean_layer_used : np.ndarray(float) of size (n_genes) + + + """ + use_lvl = shared_states[0]["use_lvl"] + if mode == "cooks" and use_lvl: + result = {} + for lvl in shared_states[0].keys(): + if lvl == "use_lvl": + continue + result[lvl] = self.final_agg_trimmed_mean_per_lvl( + [state[lvl] for state in shared_states], layer_used + )[f"trimmed_mean_{layer_used}"] + if layer_used == "sqerror": + return {"varEst": pd.DataFrame.from_dict(result).max(axis=1).to_numpy()} + else: + return {f"trimmed_mean_{layer_used}": pd.DataFrame.from_dict(result)} + elif mode == "cooks" and layer_used == "sqerror": + return { + "varEst": self.final_agg_trimmed_mean_per_lvl( + shared_states, layer_used + )["trimmed_mean_sqerror"] + } + return self.final_agg_trimmed_mean_per_lvl(shared_states, layer_used) + + def final_agg_trimmed_mean_per_lvl( + self, + shared_states: list[dict], + layer_used: str, + ) -> dict: + """ + Aggregate step of the finalisation of the trimmed mean algo. + + Parameters + ---------- + shared_states : list[dict] + List of dictionary containing the following keys: + - trimmed_local_sum : np.ndarray(float) of size (n_genes,2) + - n_samples : np.ndarray(int) of size (n_genes,2) + - num_strictly_above : np.ndarray(int) of size (n_genes,2) + - upper_bounds : np.ndarray of size (n_genes,2) + - lower_bounds : np.ndarray of size (n_genes,2) + + layer_used : str + Name of the layer used to compute the trimmed mean. + + + Returns + ------- + dict + Dictionary with the following keys: + - trimmed_mean_layer_used : np.ndarray(float) of size (n_genes) + + """ + trim_ratio = shared_states[0]["trim_ratio"] + n_samples = np.sum([state["n_samples"] for state in shared_states]) + agg_n_samples_strictly_above_quantiles = np.sum( + [state["num_strictly_above"] for state in shared_states], + axis=0, + ) + n_trim = np.floor(n_samples * trim_ratio) + targets = np.array([n_trim, n_samples - n_trim]) + effective_n_samples = n_samples - 2 * n_trim + trimmed_sum = np.sum( + [state["trimmed_local_sum"] for state in shared_states], axis=0 + ) + current_thresholds = ( + shared_states[0]["upper_bounds_thresholds"] + + shared_states[0]["lower_bounds_thresholds"] + ) / 2.0 + + # The following lines deal with the "tie" cases, i.e. where duplicate values + # fall on both part of the "n_trimmed" position. In that case, + # agg_n_samples_strictly_above_quantiles is different from target. + # "delta_sample_above_quantile" encode how many elements were wrongly + # trimmed or not trimmed. We know that these elements were close to the + # values of the threshold up to ~2^{-n_iter} precision. We can then correct the + # trimmed sum easily using the threshold values. 
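+        # Illustrative example (made-up numbers): with n_samples = 40 and
+        # trim_ratio = 0.125, targets = [5, 35]. If ties at the upper threshold
+        # leave 6 samples strictly above it instead of the targeted 5, the delta
+        # below is +1 for that gene and one threshold value is added back to the
+        # trimmed sum, compensating for the extra value that was zeroed out.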
+ + delta_sample_above_quantile = ( + agg_n_samples_strictly_above_quantiles - targets[None, :] + ) + trimmed_sum = ( + trimmed_sum + + delta_sample_above_quantile[..., 0] * current_thresholds[..., 0] + ) + trimmed_sum = ( + trimmed_sum + - delta_sample_above_quantile[..., 1] * current_thresholds[..., 1] + ) + trimmed_mean = trimmed_sum / effective_n_samples + if "scale" in shared_states[0].keys(): + scale = shared_states[0]["scale"] + trimmed_mean = trimmed_mean * scale + return {f"trimmed_mean_{layer_used}": trimmed_mean} diff --git a/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/utils.py b/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/utils.py new file mode 100644 index 0000000..3e18f7a --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/compute_trimmed_mean/utils.py @@ -0,0 +1,52 @@ +def trimfn(x: float) -> int: + """ + Determine the use-case of the trim ratio and scale based on cell counts. + + Parameters + ---------- + x : float + The number of cells. + + Returns + ------- + int + The index of the trim ratio and scale to use. + """ + return 2 if x >= 23.5 else 1 if x >= 3.5 else 0 + + +def get_trim_ratio(x): + """ + Get the trim ratio based on the number of cells. + + Parameters + ---------- + x : float + The number of cells. + + Returns + ------- + float + The trim ratio. + """ + trimratio = (1 / 3, 1 / 4, 1 / 8) + return trimratio[trimfn(x)] + + +def get_scale(x): + """ + Get the scale based on the number of cells. + + Parameters + ---------- + x : float + The number of cells. + + Returns + ------- + float + The scale used to compute the dispersion during cook distance calculation. + + """ + scales = (2.04, 1.86, 1.51) + return scales[trimfn(x)] diff --git a/fedpydeseq2/core/fed_algorithms/dispersions_grid_search/__init__.py b/fedpydeseq2/core/fed_algorithms/dispersions_grid_search/__init__.py new file mode 100644 index 0000000..18a40df --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/dispersions_grid_search/__init__.py @@ -0,0 +1,3 @@ +from fedpydeseq2.core.fed_algorithms.dispersions_grid_search.dispersions_grid_search import ( # noqa: E501 + ComputeDispersionsGridSearch, +) diff --git a/fedpydeseq2/core/fed_algorithms/dispersions_grid_search/dispersions_grid_search.py b/fedpydeseq2/core/fed_algorithms/dispersions_grid_search/dispersions_grid_search.py new file mode 100644 index 0000000..d5a44b4 --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/dispersions_grid_search/dispersions_grid_search.py @@ -0,0 +1,136 @@ +"""Main module to compute dispersions by minimizing the MLE using a grid search.""" + +from typing import Literal + +from fedpydeseq2.core.fed_algorithms.dispersions_grid_search.substeps import ( + AggGridUpdate, +) +from fedpydeseq2.core.fed_algorithms.dispersions_grid_search.substeps import LocGridLoss +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class ComputeDispersionsGridSearch( + AggGridUpdate, + LocGridLoss, +): + """ + Mixin class to implement the computation of genewise dispersions. + + The switch between genewise and MAP dispersions is done by setting the `fit_mode` + argument in the `fit_dispersions` to either "MLE" or "MAP". + + Methods + ------- + fit_dispersions + A method to fit dispersions using grid search. 
+ + """ + + grid_batch_size: int + grid_depth: int + grid_length: int + + def fit_dispersions( + self, + train_data_nodes, + aggregation_node, + local_states, + shared_state, + round_idx, + clean_models, + fit_mode: Literal["MLE", "MAP"] = "MLE", + refit_mode: bool = False, + ): + """Fit dispersions using grid search. + + Supports two modes: "MLE", to fit gene-wise dispersions, and "MAP", to fit + MAP dispersions and filter them to avoid shrinking the dispersions of genes + that are too far from the trend curve. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + shared_state: dict or None + If the fit_mode is "MLE", it is None. + If the fit_mode is "MAP", it contains the output of the trend fitting, + that is a dictionary with a "fitted_dispersion" field containing + the fitted dispersions from the trend curve, a "prior_disp_var" field + containing the prior variance of the dispersions, and a "_squared_logres" + field containing the squared residuals of the trend fitting. + + + round_idx: int + The current round. + + clean_models: bool + Whether to clean the models after the computation. + + fit_mode: str + If "MLE", gene-wise dispersions are fitted independently, and + `"genewise_dispersions"` fields are populated. If "MAP", prior + regularization is applied, `"MAP_dispersions"` fields are populated. + + refit_mode: bool + Whether to run on `refit_adata`s instead of `local_adata`s (default: False). + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + shared_state: dict or list[dict] + A dictionary containing: + - "genewise_dispersions": The MLE dispersions, to be stored locally at + - "lower_log_bounds": log lower bounds for the grid search (only used in + internal loop), + - "upper_log_bounds": log upper bounds for the grid search (only used in + internal loop). + + round_idx: int + The updated round index. + """ + for _ in range(self.grid_depth): + # Compute local loss summands at all grid points. + local_states, shared_states, round_idx = local_step( + local_method=self.local_grid_loss, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=shared_state, + aggregation_id=aggregation_node.organization_id, + description="Compute local grid loss summands.", + round_idx=round_idx, + clean_models=clean_models, + method_params={ + "prior_reg": fit_mode == "MAP", + "refit_mode": refit_mode, + }, + ) + + # Aggregate local summands and refine the search interval. 
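+            # Each pass shrinks the per-gene search window to twice the previous
+            # grid spacing (see `global_grid_update`), so `grid_depth` passes
+            # refine the estimate by roughly a factor of grid_length / 2 per pass
+            # in log space.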
+ shared_state, round_idx = aggregation_step( + aggregation_method=self.global_grid_update, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Perform a global grid search update.", + round_idx=round_idx, + clean_models=clean_models, + method_params={ + "prior_reg": fit_mode == "MAP", + "dispersions_param_name": "genewise_dispersions" + if fit_mode == "MLE" + else "MAP_dispersions", + }, + ) + + return local_states, shared_state, round_idx diff --git a/fedpydeseq2/core/fed_algorithms/dispersions_grid_search/substeps.py b/fedpydeseq2/core/fed_algorithms/dispersions_grid_search/substeps.py new file mode 100644 index 0000000..6ef1a67 --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/dispersions_grid_search/substeps.py @@ -0,0 +1,254 @@ +"""Module to implement the substeps to fit dispersions with MLE. + +This module contains all the substeps to fit dispersions using a grid search. +""" + + +import numpy as np +from anndata import AnnData +from joblib import Parallel # type: ignore +from joblib import delayed +from joblib import parallel_backend +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils import global_grid_cr_loss +from fedpydeseq2.core.utils import local_grid_summands +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocGridLoss: + """Mixin to compute local MLE summands on a grid.""" + + local_adata: AnnData + refit_adata: AnnData + grid_batch_size: int + grid_length: int + min_disp: float + num_jobs: int + joblib_backend: str + + @remote_data + @log_remote_data + @reconstruct_adatas + def local_grid_loss( + self, + data_from_opener, + shared_state, + prior_reg: bool = False, + refit_mode: bool = False, + ) -> dict: + """ + Compute local MLE losses and Cox-Reid summands on a grid. + + Parameters + ---------- + data_from_opener : ad.AnnData + Not used. + + shared_state : dict, optional + Shared states with the previous search intervals "lower_log_bounds" and + "upper_log_bounds", except at initial step where it is None in the case + of gene-wise dispersions, or contains the output of the trend fitting + in the case of MAP dispersions. + + prior_reg : bool + Whether to include prior regularization, for MAP estimation + (default: False). + + refit_mode : bool + Whether to run on `refit_adata`s instead of `local_adata`s (default: False). + + Returns + ------- + dict + Keys: + - "nll": local negative log-likelihoods (n_genes x grid_length), + - "CR_summand": local Cox-Reid adjustment summands + (n_params x n_params x n_genes x grid_length), + - "grid": grid of dispersions to evaluate (n_genes x grid_length), + - "n_samples": number of samples in the local dataset, + - "max_disp": global upper bound on dispersions. + - "non_zero": mask of all zero genes. + - "reg": quadratic regularization term for MAP estimation (only if + `prior_reg=True`). + """ + if refit_mode: + adata = self.refit_adata + else: + adata = self.local_adata + + # If we are fitting MAP dispersions and this is the first iteration, we need + # to save the results of the trend curve fitting. + # In refit mode, we can use the results from the previous iteration. 
+ if not refit_mode: + if prior_reg and ("fitted_dispersions" not in self.local_adata.varm): + self.local_adata.varm["fitted_dispersions"] = shared_state[ + "fitted_dispersions" + ] + self.local_adata.uns["trend_coeffs"] = shared_state["trend_coeffs"] + self.local_adata.uns["prior_disp_var"] = shared_state["prior_disp_var"] + self.local_adata.uns["_squared_logres"] = shared_state[ + "_squared_logres" + ] + self.local_adata.uns["disp_function_type"] = shared_state[ + "disp_function_type" + ] + self.local_adata.uns["mean_disp"] = shared_state["mean_disp"] + + # Compute log space grids + if (shared_state is not None) and ("lower_log_bounds" in shared_state): + # Get the bounds from the previous iteration. Each gene has its own bounds. + min_log_alpha = shared_state["lower_log_bounds"] # ndarray (n_genes) + max_log_alpha = shared_state["upper_log_bounds"] # ndarray (n_genes) + grid = np.exp(np.linspace(min_log_alpha, max_log_alpha, self.grid_length)).T + # of size n_genes x grid_length + else: + # At first iteration, all genes get the same grid + min_log_alpha = np.log(self.min_disp) # float + max_log_alpha = np.log(adata.uns["max_disp"]) # float + grid = np.exp(np.linspace(min_log_alpha, max_log_alpha, self.grid_length)) + # of size n_genes x grid_length + grid = np.repeat(grid[None, :], adata.n_vars, axis=0) + + design = adata.obsm["design_matrix"].values + n_params = design.shape[1] + + with parallel_backend(self.joblib_backend): + res = Parallel( + n_jobs=self.num_jobs, + )( + delayed(local_grid_summands)( + counts=adata.X[:, i : i + self.grid_batch_size], + design=design, + mu=adata.layers["_mu_hat"][:, i : i + self.grid_batch_size], + alpha_grid=grid[i : i + self.grid_batch_size, :], + ) + for i in range(0, adata.n_vars, self.grid_batch_size) + ) + if len(res) == 0: + nll = np.zeros((0, self.grid_length)) + CR_summand = np.zeros( + (0, self.grid_length, n_params, n_params), + ) + else: + nll = np.vstack([x[0] for x in res]) + CR_summand = np.vstack([x[1] for x in res]) + + result_shared_state = { + "nll": nll, + "CR_summand": CR_summand, + "grid": grid, + "max_disp": adata.uns["max_disp"], + "non_zero": adata.varm["non_zero"], + } + + if prior_reg: + reg = ( + np.log(grid) - np.log(adata.varm["fitted_dispersions"])[:, None] + ) ** 2 / (2 * adata.uns["prior_disp_var"]) + + result_shared_state["reg"] = reg + + return result_shared_state + + +class AggGridUpdate: + """Mixin to compute global MLE grid updates.""" + + min_disp: float + grid_batch_size: int + num_jobs: int + joblib_backend: str + + @remote + @log_remote + def global_grid_update( + self, + shared_states, + prior_reg: bool = False, + dispersions_param_name: str = "genewise_dispersions", + ) -> dict: + """Aggregate local MLE summands on a grid and update global dispersion. + + Also sets new search intervals for recursion. + + Parameters + ---------- + shared_states : list + List of local states dictionaries, with: + - "nll": local negative log-likelihoods (n_genes x grid_length), + - "CR_summand": local Cox-Reid adjustment summands + (n_params x n_params x n_genes x grid_length), + - "grid": grid of dispersions that were evaluated (n_genes x grid_length), + - "max_disp": global upper bound on dispersions. + - "reg": prior regularization to add for MAP dispersions + (only if prior_reg is True). + + prior_reg : bool + Whether to include prior regularization, for MAP estimation + (default: False). + + + dispersions_param_name : str + Name of the dispersion parameter to update. Dispersions will be saved under + this name. 
(default: "genewise_dispersions"). + + Returns + ------- + dict + Keys: + - dispersions_param_name: updated dispersions (n_genes), + - "lower_log_bounds": updated lower log bounds (n_genes), + - "upper_log_bounds": updated upper log bounds (n_genes). + """ + nll = sum([state["nll"] for state in shared_states]) + global_CR_summand = sum([state["CR_summand"] for state in shared_states]) + + # Compute (batched) global losses + with parallel_backend(self.joblib_backend): + res = Parallel( + n_jobs=self.num_jobs, + )( + delayed(global_grid_cr_loss)( + nll=nll[i : i + self.grid_batch_size], + cr_grid=global_CR_summand[i : i + self.grid_batch_size], + ) + for i in range(0, len(nll), self.grid_batch_size) + ) + + if len(res) == 0: + global_losses = np.zeros((0, nll.shape[1])) + else: + global_losses = np.concatenate(res, axis=0) + + if prior_reg: + global_losses += shared_states[0]["reg"] + + # For each gene, find the argmin alpha, and the new search interval + grids = shared_states[0]["grid"] + # min_idx of shape n_genes + min_idx = np.argmin(global_losses, axis=1) + # delta of shape n_genes + alpha = grids[np.arange(len(grids)), min_idx] + + # Compute the new bounds + # Note: the grid should be in log space + delta_grid = np.log(grids[:, 1]) - np.log(grids[:, 0]) + log_grid_lower_bounds = np.maximum( + np.log(self.min_disp), np.log(alpha) - delta_grid + ) + log_grid_upper_bounds = np.minimum( + np.log(shared_states[0]["max_disp"]), np.log(alpha) + delta_grid + ) + + # Set the dispersions of all-zero genes to NaN + alpha[~shared_states[0]["non_zero"]] = np.NaN + + return { + dispersions_param_name: alpha, + "lower_log_bounds": log_grid_lower_bounds, + "upper_log_bounds": log_grid_upper_bounds, + } diff --git a/fedpydeseq2/core/fed_algorithms/fed_PQN/__init__.py b/fedpydeseq2/core/fed_algorithms/fed_PQN/__init__.py new file mode 100644 index 0000000..01846d3 --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/fed_PQN/__init__.py @@ -0,0 +1,3 @@ +"""Necessary mixin and utils for prox newton method.""" + +from fedpydeseq2.core.fed_algorithms.fed_PQN.fed_PQN import FedProxQuasiNewton diff --git a/fedpydeseq2/core/fed_algorithms/fed_PQN/fed_PQN.py b/fedpydeseq2/core/fed_algorithms/fed_PQN/fed_PQN.py new file mode 100644 index 0000000..7d12ed0 --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/fed_PQN/fed_PQN.py @@ -0,0 +1,133 @@ +from typing import Literal + +from fedpydeseq2.core.fed_algorithms.fed_PQN.substeps import ( + AggChooseStepComputeAscentDirection, +) +from fedpydeseq2.core.fed_algorithms.fed_PQN.substeps import ( + LocMakeFedPQNFisherGradientNLL, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class FedProxQuasiNewton( + LocMakeFedPQNFisherGradientNLL, AggChooseStepComputeAscentDirection +): + """Mixin class to implement a Prox Newton method for box constraints. + + It implements the method presented here: + https://www.cs.utexas.edu/~inderjit/public_papers/pqnj_sisc10.pdf + More context can be found here + https://optml.mit.edu/papers/sksChap.pdf + + Methods + ------- + run_fed_PQN + The method to run the Prox Quasi Newton algorithm. + It relies on the methods inherited from the LocMakeFedPQNFisherGradientNLL and + AggChooseStepComputeAscentDirection classes. 
+ + """ + + PQN_num_iters: int + + def run_fed_PQN( + self, + train_data_nodes, + aggregation_node, + local_states, + PQN_shared_state, + first_iteration_mode: Literal["irls_catch"] | None, + round_idx, + clean_models, + refit_mode: bool = False, + ): + """Run the Prox Quasi Newton algorithm. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + PQN_shared_state: dict + The input shared state. + The requirements for this shared state are defined in the + LocMakeFedPQNFisherGradientNLL class and depend on the + first_iteration_mode. + + first_iteration_mode: Optional[Literal["irls_catch"]] + The first iteration mode. + This defines the input requirements for the algorithm, and is passed + to the make_local_fisher_gradient_nll method at the first iteration. + + round_idx: int + The current round. + + clean_models: bool + If True, the models are cleaned. + + refit_mode: bool + Whether to run on `refit_adata`s instead of `local_adata`s. + (default: False). + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + irls_final_shared_states: dict + Shared states containing the final IRLS results. + It contains nothing for now. + + round_idx: int + The updated round index. + + """ + #### ---- Main training loop ---- ##### + + for pqn_iter in range(self.PQN_num_iters + 1): + # ---- Compute local IRLS summands and nlls ---- # + + ( + local_states, + local_fisher_gradient_nlls_shared_states, + round_idx, + ) = local_step( + local_method=self.make_local_fisher_gradient_nll, + method_params={ + "first_iteration_mode": first_iteration_mode + if pqn_iter == 0 + else None, + "refit_mode": refit_mode, + }, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=PQN_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Compute local Prox Newton summands and nlls.", + clean_models=clean_models, + ) + + # ---- Compute global IRLS update and nlls ---- # + + PQN_shared_state, round_idx = aggregation_step( + aggregation_method=self.choose_step_and_compute_ascent_direction, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=local_fisher_gradient_nlls_shared_states, + round_idx=round_idx, + description="Update the log fold changes and nlls in IRLS.", + clean_models=clean_models, + ) + + #### ---- End of training ---- #### + + return local_states, PQN_shared_state, round_idx diff --git a/fedpydeseq2/core/fed_algorithms/fed_PQN/substeps.py b/fedpydeseq2/core/fed_algorithms/fed_PQN/substeps.py new file mode 100644 index 0000000..bc2a034 --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/fed_PQN/substeps.py @@ -0,0 +1,742 @@ +from typing import Any +from typing import Literal + +import numpy as np +from anndata import AnnData +from joblib import Parallel +from joblib import delayed +from joblib import parallel_backend +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.fed_algorithms.fed_PQN.utils import ( + compute_ascent_direction_decrement, +) +from fedpydeseq2.core.fed_algorithms.fed_PQN.utils import ( + compute_gradient_scaling_matrix_fisher, +) +from fedpydeseq2.core.fed_algorithms.fed_PQN.utils import ( + make_fisher_gradient_nll_step_sizes_batch, +) +from 
fedpydeseq2.core.utils.compute_lfc_utils import get_lfc_utils_from_gene_mask_adata +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocMakeFedPQNFisherGradientNLL: + """Mixin to compute local values, gradient and Fisher information of the NLL. + + Attributes + ---------- + local_adata : AnnData + The local AnnData. + num_jobs : int + The number of cpus to use. + joblib_verbosity : int + The joblib verbosity. + joblib_backend : str + The backend to use for the IRLS algorithm. + irls_batch_size : int + The batch size to use for the IRLS algorithm. + max_beta : float + The maximum value for the beta parameter. + PQN_num_iters_ls : int + The number of iterations to use for the line search. + PQN_min_mu : float + The min_mu parameter for the Proximal Quasi Newton algorithm. + + Methods + ------- + make_local_fisher_gradient_nll + A remote_data method. + Make the local nll, gradient and fisher matrix. + + """ + + local_adata: AnnData + refit_adata: AnnData + num_jobs: int + joblib_verbosity: int + joblib_backend: str + irls_batch_size: int + max_beta: float + PQN_num_iters_ls: int + PQN_min_mu: float + + @remote_data + @log_remote_data + @reconstruct_adatas + def make_local_fisher_gradient_nll( + self, + data_from_opener: AnnData, + shared_state: dict[str, Any], + first_iteration_mode: Literal["irls_catch"] | None = None, + refit_mode: bool = False, + ): + r"""Make the local nll, gradient and fisher information matrix. + + Given an ascent direction :math:`d` (an ascent direction being positively + correlated to the gradient of the starting point) and a starting point + :math:`beta`, this function + computes the nll, gradient and Fisher information at the points + :math:`beta + t * d`, + for :math:`t` in step_sizes + (step sizes are :math:`0.5^i` for :math:`i` in :math:`0,...,19`. + + + Moreover, if the iteration is the first one, the step sizes are not used, + and instead, the nll, gradient and fisher information are computed at the + current beta values. + + Parameters + ---------- + data_from_opener : AnnData + Not used. + + shared_state : dict + A dictionary containing the following + keys: + - PQN_mask: ndarray + A boolean mask indicating if the gene should be used for the + proximal newton step. + It is of shape (n_non_zero_genes,) + Used, but not modified. + - round_number_PQN: int + The current round number of the prox newton algorithm. + Used but not modified. + - ascent_direction_on_mask: Optional[ndarray] + The ascent direction, of shape (n_genes, n_params), where + n_genes is the current number of genes that are active (True + in the PQN_mask). + Used but not modified. + - beta: ndarray + The log fold changes, of shape (n_non_zero_genes, n_params). + Used but not modified. + - global_reg_nll: ndarray + The global regularized nll, of shape (n_non_zero_genes,). + Not used and not modified. + - newton_decrement_on_mask: Optional[ndarray] + The newton decrement, of shape (n_ngenes,). + It is None at the first round of the prox newton algorithm. + Not used and not modified. + - irls_diverged_mask: ndarray + A boolean mask indicating if the gene has diverged in the IRLS + algorithm. + Not used and not modified. + - PQN_diverged_mask: ndarray + A boolean mask indicating if the gene has diverged in the prox newton + algorithm. + Not used and not modified. 
+ + first_iteration_mode : Optional[Literal["irls_catch"]] + For the first iteration, this function behaves differently. If + first_iteration_mode is None, then we are not at the first iteration. + If first_iteration_mode is not None, the function will expect a + different shared state than the one described above, and will construct + the initial shared state from it. + If first_iteration_mode is "irls_catch", then we assume that + we are using the PQN algorithm as a method to catch IRLS when it fails + The function will expect a + shared state that contains the following fields: + - beta: ndarray + The log fold changes, of shape (n_non_zero_genes, n_params). + - irls_diverged_mask: ndarray + A boolean mask indicating if the gene has diverged in the IRLS + algorithm. + - irls_mask : ndarray + The mask of genes that were still active for the IRLS algorithm. + + refit_mode : bool + Whether to run on `refit_adata`s instead of `local_adata`s. + (default: False). + + Returns + ------- + dict + The state to share to the server. + It contains the following fields: + - beta: ndarray + The log fold changes, of shape (n_non_zero_genes, n_params). + - local_nll: ndarray + The local nll, of shape (n_step_sizes, n_genes,), where + n_genes is the current number of genes that are active (True + in the PQN_mask). n_step_sizes is the number of step sizes + considered, which is `PQN_num_iters_ls` if we are not at the + first round, and 1 otherwise. + This is created during this step. + - local_fisher: ndarray + The local fisher matrix, + of shape (n_step_sizes, n_genes, n_params, n_params). + This is created during this step. + - local_gradient: ndarray + The local gradient, of shape (n_step_sizes, n_genes, n_params). + This is created during this step. + - PQN_diverged_mask: ndarray + A boolean mask indicating if the gene has diverged in the prox newton + algorithm. + - PQN_mask: ndarray + A boolean mask indicating if the gene should be used for the + proximal newton step, of shape (n_non_zero_genes,). + - global_reg_nll: ndarray + The global regularized nll, of shape (n_non_zero_genes,). + - newton_decrement_on_mask: Optional[ndarray] + The newton decrement, of shape (n_ngenes,). + This is None at the first round of the prox newton algorithm. + - round_number_PQN: int + The current round number of the prox newton algorithm. + - irls_diverged_mask: ndarray + A boolean mask indicating if the gene has diverged in the IRLS + algorithm. + - ascent_direction_on_mask: Optional[ndarray] + The ascent direction, of shape (n_genes, n_params), where + n_genes is the current number of genes that are active (True + in the PQN_mask). + + Raises + ------ + ValueError + If first_iteration_mode is not None or "irls_catch". 
+ """ + if refit_mode: + adata = self.refit_adata + else: + adata = self.local_adata + + # Distinguish between the first iteration and the rest + if first_iteration_mode is not None and first_iteration_mode == "irls_catch": + beta = shared_state["beta"] + irls_diverged_mask = shared_state["irls_diverged_mask"] + irls_mask = shared_state["irls_mask"] + PQN_mask = irls_mask | irls_diverged_mask + irls_diverged_mask = PQN_mask.copy() + round_number_PQN = 0 + ascent_direction_on_mask = None + newton_decrement_on_mask = None + PQN_diverged_mask = np.zeros_like(irls_mask, dtype=bool) + global_reg_nll = np.nan * np.ones_like(irls_mask, dtype=float) + elif first_iteration_mode is None: + # If we are not at the first iteration, we use the shared state + PQN_mask = shared_state["PQN_mask"] + round_number_PQN = shared_state["round_number_PQN"] + ascent_direction_on_mask = shared_state["ascent_direction_on_mask"] + beta = shared_state["beta"] + PQN_diverged_mask = shared_state["PQN_diverged_mask"] + newton_decrement_on_mask = shared_state["newton_decrement_on_mask"] + global_reg_nll = shared_state["global_reg_nll"] + irls_diverged_mask = shared_state["irls_diverged_mask"] + else: + raise ValueError("first_iteration_mode should be None or 'irls_catch'") + + if round_number_PQN == 0: + # Sanity check that this is the first round of fed prox + beta[PQN_mask] = adata.uns["_irls_beta_init"][PQN_mask] + step_sizes: np.ndarray | None = None + + else: + step_sizes = 0.5 ** np.arange(self.PQN_num_iters_ls) + + # Get the quantities stored in the adata + disp_param_name = adata.uns["_irls_disp_param_name"] + + ( + PQN_gene_names, + design_matrix, + size_factors, + counts, + dispersions, + beta_on_mask, + ) = get_lfc_utils_from_gene_mask_adata( + adata, + PQN_mask, + disp_param_name, + beta=beta, + ) + + # ---- Compute local nll, gradient and Fisher information ---- # + + with parallel_backend(self.joblib_backend): + res = Parallel(n_jobs=self.num_jobs, verbose=self.joblib_verbosity)( + delayed(make_fisher_gradient_nll_step_sizes_batch)( + design_matrix=design_matrix, + size_factors=size_factors, + beta=beta_on_mask[i : i + self.irls_batch_size], + dispersions=dispersions[i : i + self.irls_batch_size], + counts=counts[:, i : i + self.irls_batch_size], + ascent_direction=ascent_direction_on_mask[ + i : i + self.irls_batch_size + ] + if ascent_direction_on_mask is not None + else None, + step_sizes=step_sizes, + beta_min=-self.max_beta, + beta_max=self.max_beta, + min_mu=self.PQN_min_mu, + ) + for i in range(0, len(beta_on_mask), self.irls_batch_size) + ) + + n_step_sizes = len(step_sizes) if step_sizes is not None else 1 + if len(res) == 0: + H = np.zeros((n_step_sizes, 0, beta.shape[1], beta.shape[1])) + gradient = np.zeros((n_step_sizes, 0, beta.shape[1])) + local_nll = np.zeros((n_step_sizes, 0)) + else: + H = np.concatenate([r[0] for r in res], axis=1) + gradient = np.concatenate([r[1] for r in res], axis=1) + local_nll = np.concatenate([r[2] for r in res], axis=1) + + # Create the shared state + return { + "beta": beta, + "local_nll": local_nll, + "local_fisher": H, + "local_gradient": gradient, + "PQN_diverged_mask": PQN_diverged_mask, + "PQN_mask": PQN_mask, + "global_reg_nll": global_reg_nll, + "newton_decrement_on_mask": newton_decrement_on_mask, + "round_number_PQN": round_number_PQN, + "irls_diverged_mask": irls_diverged_mask, + "ascent_direction_on_mask": ascent_direction_on_mask, + } + + +class AggChooseStepComputeAscentDirection: + """Mixin class to compute the right ascent direction. 
+ + An ascent direction is a direction that is positively correlated to the gradient. + This direction will be used to compute the next iterate in the proximal quasi newton + algorithm. As our aim will be to mimimize the negative log likelihood, we will + move in the opposite direction, that is in the direction of minus the + ascent direction. + + Attributes + ---------- + num_jobs : int + The number of cpus to use. + joblib_verbosity : int + The joblib verbosity. + joblib_backend : str + The backend to use for the IRLS algorithm. + irls_batch_size : int + The batch size to use for the IRLS algorithm. + max_beta : float + The maximum value for the beta parameter. + beta_tol : float + The tolerance for the beta parameter. + PQN_num_iters_ls : int + The number of iterations to use for the line search. + PQN_c1 : float + The c1 parameter for the line search. + PQN_ftol : float + The ftol parameter for the line search. + PQN_num_iters : int + The number of iterations to use for the proximal quasi newton algorithm. + + Methods + ------- + choose_step_and_compute_ascent_direction + A remote method. + Choose the best step size and compute the next ascent direction. + """ + + num_jobs: int + joblib_verbosity: int + joblib_backend: str + irls_batch_size: int + max_beta: float + beta_tol: float + PQN_num_iters_ls: int + PQN_c1: float + PQN_ftol: float + PQN_num_iters: int + + @remote + @log_remote + def choose_step_and_compute_ascent_direction( + self, shared_states: list[dict] + ) -> dict[str, Any]: + """Choose best step size and compute next ascent direction. + + By "ascent direction", we mean the direction that is positively correlated + with the gradient. + + The role of this function is twofold. + + 1) It chooses the best step size for each gene, and updates the beta values + as well as the nll values. This allows to define the next iterate. + Note that at the first iterate, it simply computes the nll, gradient and fisher + information at the current beta values, to define the next ascent direction. + + 2) For this new iterate (or the current one if we are at the first round), + it computes the gradient scaling matrix, which is used to scale the gradient + in the proximal newton algorithm. From this gradient scaling matrix, and the + gradient, it computes the ascent direction (and the newton decrement). + + + Parameters + ---------- + shared_states: list[dict] + A list of dictionaries containing the following + keys: + - beta: ndarray + The log fold changes, of shape (n_non_zero_genes, n_params). + - local_nll: ndarray + The local nll, of shape (n_genes,), where + n_genes is the current number of genes that are active (True + in the PQN_mask). + - local_fisher: ndarray + The local fisher matrix, of shape (n_genes, n_params, n_params). + - local_gradient: ndarray + The local gradient, of shape (n_genes, n_params). + - PQN_diverged_mask: ndarray + A boolean mask indicating if the gene has diverged in the prox newton + algorithm, of shape (n_non_zero_genes,). + - PQN_mask: ndarray + A boolean mask indicating if the gene should be used for the + proximal newton step, of shape (n_non_zero_genes,). + - global_reg_nll: ndarray + The global regularized nll, of shape (n_non_zero_genes,). + - newton_decrement_on_mask: Optional[ndarray] + The newton decrement, of shape (n_ngenes,). + This is None at the first round of the prox newton algorithm. + - round_number_PQN: int + The current round number of the prox newton algorithm. 
+            - ascent_direction_on_mask: Optional[ndarray]
+                The ascent direction, of shape (n_genes, n_params), where
+                n_genes is the current number of genes that are active (True
+                in the PQN_mask).
+            - irls_diverged_mask: ndarray
+                A boolean mask indicating if the gene has diverged in the IRLS
+                algorithm.
+
+        Returns
+        -------
+        dict[str, Any]
+            A dictionary containing all the necessary info to run the method.
+            If we are not at the last iteration, it contains the following fields:
+            - beta: ndarray
+                The log fold changes, of shape (n_non_zero_genes, n_params).
+            - PQN_mask: ndarray
+                A boolean mask indicating if the gene should be used for the
+                proximal newton step.
+                It is of shape (n_non_zero_genes,)
+            - PQN_diverged_mask: ndarray
+                A boolean mask indicating if the gene has diverged in the prox newton
+                algorithm. It is of shape (n_non_zero_genes,)
+            - ascent_direction_on_mask: np.ndarray
+                The ascent direction, of shape (n_genes, n_params), where
+                n_genes is the current number of genes that are active (True
+                in the PQN_mask).
+            - newton_decrement_on_mask: np.ndarray
+                The newton decrement, of shape (n_genes,).
+            - round_number_PQN: int
+                The current round number of the prox newton algorithm.
+            - global_reg_nll: ndarray
+                The global regularized nll, of shape (n_non_zero_genes,).
+            - irls_diverged_mask: ndarray
+                A boolean mask indicating if the gene has diverged in the IRLS
+                algorithm.
+            If we are at the last iteration, it contains the following fields:
+            - beta: ndarray
+                The log fold changes, of shape (n_non_zero_genes, n_params).
+            - PQN_diverged_mask: ndarray
+                A boolean mask indicating if the gene has diverged in the prox newton
+                algorithm. It is of shape (n_non_zero_genes,)
+            - irls_diverged_mask: ndarray
+                A boolean mask indicating if the gene has diverged in the IRLS
+                algorithm.
+
+        """
+        # We use the following naming convention: when we say "on mask", we mean
+        # that we restrict the quantity to the genes that are active in the
+        # proximal newton algorithm. We therefore need to ensure that these
+        # quantities are readjusted when we change the proximal quasi newton mask.
+
+        # Load main params from the first state
+        beta = shared_states[0]["beta"]
+        PQN_diverged_mask = shared_states[0]["PQN_diverged_mask"]
+        PQN_mask = shared_states[0]["PQN_mask"]
+        reg_nll = shared_states[0]["global_reg_nll"]
+        ascent_direction_on_mask = shared_states[0]["ascent_direction_on_mask"]
+        newton_decrement_on_mask = shared_states[0]["newton_decrement_on_mask"]
+        round_number_PQN = shared_states[0]["round_number_PQN"]
+
+        reg_parameter = 1e-6
+
+        # ---- Step 0: Aggregate the nll, gradient and fisher info ---- #
+
+        new_fisher_options_on_mask = sum(
+            [state["local_fisher"] for state in shared_states]
+        )
+
+        new_gradient_options_on_mask = sum(
+            [state["local_gradient"] for state in shared_states]
+        )
+        new_reg_nll_options_on_mask = sum(
+            [state["local_nll"] for state in shared_states]
+        )
+
+        # ---- Step 1: Add the regularization term ---- #
+
+        # ---- Step 1a: Compute the new beta options ---- #
+
+        # In order to regularize, we have to compute the beta values at which
+        # the nll, gradient and Fisher information were evaluated in the local steps.
+
+        beta_on_mask = beta[PQN_mask]
+
+        if round_number_PQN == 0:
+            # In this case, there is no line search, and only
+            # beta is considered in the local steps.
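+            # A single "option" is therefore kept, of shape
+            # (1, n_active_genes, n_params), matching the single evaluation
+            # (no step sizes) returned by the local steps at round 0.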
+ new_beta_options_on_mask = beta_on_mask[None, :] + + else: + # In this case, there is a line search, and we have to + # compute the new beta options + assert ascent_direction_on_mask is not None + step_sizes = 0.5 ** np.arange(self.PQN_num_iters_ls) + new_beta_options_on_mask = np.clip( + beta_on_mask[None, :, :] + - step_sizes[:, None, None] * ascent_direction_on_mask[None, :, :], + -self.max_beta, + self.max_beta, + ) + + # ---- Step 1b: Add the regularization ---- # + + # Add a regularization term to fisher info + + if new_fisher_options_on_mask is not None: + # Add the regularization term to construct the Fisher info with prior + # from the Fisher info without prior + cross_term = ( + new_gradient_options_on_mask[:, :, :, None] + @ new_beta_options_on_mask[:, :, None, :] + ) + beta_term = ( + new_beta_options_on_mask[:, :, :, None] + @ new_beta_options_on_mask[:, :, None, :] + ) + new_fisher_options_on_mask += ( + reg_parameter * (cross_term + cross_term.transpose(0, 1, 3, 2)) + + reg_parameter**2 * beta_term + ) + + # Furthermore, add a ridge term to the Fisher info for numerical stability + # This factor decreases log linearly between and initial and final reg + # The decreasing factor is to ensure that the first steps correspond to + # gradient descent steps, as we are too far from the optimum + # to use the Fisher info. + # Note that other schemes seem to work as well: 1 for 20 iterations then + # 1e-6 + # 1 for 20 iterations then 1e-2 (to confirm), or 1 for 20 iterations and + # then + # 1/n_samples. + initial_reg_fisher = 1 + final_reg_fisher = 1e-6 + reg_fisher = initial_reg_fisher * ( + final_reg_fisher / initial_reg_fisher + ) ** (round_number_PQN / self.PQN_num_iters) + + new_fisher_options_on_mask = ( + new_fisher_options_on_mask + + np.diag(np.repeat(reg_fisher, new_fisher_options_on_mask.shape[-1]))[ + None, None, :, : + ] + ) + + # Add regularization term to gradient + new_gradient_options_on_mask += reg_parameter * new_beta_options_on_mask + + # Add regularization term to the nll + new_reg_nll_options_on_mask += ( + 0.5 * reg_parameter * np.sum(new_beta_options_on_mask**2, axis=2) + ) + + # ---- Step 2: Compute best step size, and new values for this step size ---- # + + # This is only done if we are not at the first round of the prox newton + # algorithm, as the first rounds serves only to evaluate the nll, gradient + # and fisher info at the current beta values, and compute the first + # ascent direction. + + if round_number_PQN > 0: + # ---- Step 2a: See which step sizes pass the selection criteria ---- # + + assert reg_nll is not None + reg_nll_on_mask = reg_nll[PQN_mask] + + obj_diff_options_on_mask = ( + reg_nll_on_mask[None, :] - new_reg_nll_options_on_mask + ) # of shape n_steps, n_PQN_genes + + step_sizes = 0.5 ** np.arange(self.PQN_num_iters_ls) + + # Condition 1: Armijo condition + # This condition is also called the first Wolfe condition. + # Reference https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf + admissible_step_size_options_mask = ( + obj_diff_options_on_mask + >= self.PQN_c1 * step_sizes[:, None] * newton_decrement_on_mask[None, :] + ) + + # ---- Step 2b: Identify genes that have diverged, and remove them ---- # + + # For each gene, we check if there is at least one step size that satisfies + # the selection criteria. If there is none, we consider that the gene has + # diverged: we remove all such genes from the PQN_mask + # and add them to the PQN_diverged_mask. 
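+            # Added clarification: admissible_step_size_options_mask has shape
+            # (n_steps, n_PQN_genes), so np.all(~mask, axis=0) below is True for
+            # a gene exactly when none of the candidate step sizes satisfies the
+            # Armijo condition.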
+ + diverged_gene_mask_in_current_PQN_mask = np.all( + ~admissible_step_size_options_mask, axis=0 + ) + + # Remove these diverged genes for which we cannot find + # a correct step size + + PQN_diverged_mask[PQN_mask] = diverged_gene_mask_in_current_PQN_mask + PQN_mask[PQN_mask] = ~diverged_gene_mask_in_current_PQN_mask + + # Restrict all the quantities defined on the prox newton + # mask to the new prox newton mask + + obj_diff_options_on_mask = obj_diff_options_on_mask[ + :, ~diverged_gene_mask_in_current_PQN_mask + ] + reg_nll_on_mask = reg_nll_on_mask[~diverged_gene_mask_in_current_PQN_mask] + beta_on_mask = beta_on_mask[~diverged_gene_mask_in_current_PQN_mask] + admissible_step_size_options_mask = admissible_step_size_options_mask[ + :, ~diverged_gene_mask_in_current_PQN_mask + ] + new_reg_nll_options_on_mask = new_reg_nll_options_on_mask[ + :, ~diverged_gene_mask_in_current_PQN_mask + ] + new_gradient_options_on_mask = new_gradient_options_on_mask[ + :, ~diverged_gene_mask_in_current_PQN_mask, : + ] + new_beta_options_on_mask = new_beta_options_on_mask[ + :, ~diverged_gene_mask_in_current_PQN_mask, : + ] + + new_fisher_options_on_mask = new_fisher_options_on_mask[ + :, ~diverged_gene_mask_in_current_PQN_mask, :, : + ] + + # ---- Step 2c: Find the best step size for each gene ---- # + + # Here, we find the best step size for each gene that satisfies the + # selection criteria (i.e. the largest). + # We do this by finding the first index for which + # the admissible step size mask is True. + # We then create the new beta, gradient, fisher info and reg nll by + # taking the option corresponding to the best step size + + new_step_size_index = np.argmax(admissible_step_size_options_mask, axis=0) + arange_PQN = np.arange(len(new_step_size_index)) + + new_beta_on_mask = new_beta_options_on_mask[new_step_size_index, arange_PQN] + new_gradient_on_mask = new_gradient_options_on_mask[ + new_step_size_index, arange_PQN + ] + new_fisher_on_mask = new_fisher_options_on_mask[ + new_step_size_index, arange_PQN + ] + + obj_diff_on_mask = obj_diff_options_on_mask[new_step_size_index, arange_PQN] + + new_reg_nll_on_mask = new_reg_nll_options_on_mask[ + new_step_size_index, arange_PQN + ] + + # ---- Step 2d: Update the beta values and the reg_nll values ---- # + + beta[PQN_mask] = new_beta_on_mask + reg_nll[PQN_mask] = new_reg_nll_on_mask + + # ---- Step 2e: Check for convergence of the method ---- # + + convergence_mask = ( + np.abs(obj_diff_on_mask) + / ( + np.maximum( + np.maximum( + np.abs(new_reg_nll_on_mask), + np.abs(reg_nll_on_mask), + ), + 1, + ) + ) + < self.PQN_ftol + ) + + # ---- Step 2f: Remove converged genes from the mask ---- # + PQN_mask[PQN_mask] = ~convergence_mask + + # If we reach the max number of iterations, we stop + if round_number_PQN == self.PQN_num_iters: + # In this case, we are finished. 
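+                # Added clarification: genes that are still active in PQN_mask at
+                # this point have not converged within the iteration budget, so
+                # the union PQN_diverged_mask | PQN_mask below reports them as
+                # diverged in the final shared state.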
+ return { + "beta": beta, + "PQN_diverged_mask": PQN_diverged_mask | PQN_mask, + "irls_diverged_mask": shared_states[0]["irls_diverged_mask"], + } + + # We restrict all quantities to the new mask + + new_gradient_on_mask = new_gradient_on_mask[~convergence_mask] + new_beta_on_mask = new_beta_on_mask[~convergence_mask] + new_fisher_on_mask = new_fisher_on_mask[~convergence_mask] + + # Note, this is the old beta + beta_on_mask = beta_on_mask[~convergence_mask] + + else: + # In this case, we are at the first round of the prox newton algorithm + # In this case, we simply instantiate the new values to the first + # values that were computed in the local steps, to be able to compute + # the first ascent direction. + beta_on_mask = None + new_gradient_on_mask = new_gradient_options_on_mask[0] + new_beta_on_mask = new_beta_options_on_mask[0] + new_fisher_on_mask = new_fisher_options_on_mask[0] + + # Set the nll + reg_nll[PQN_mask] = new_reg_nll_options_on_mask[0] + + # ---- Step 3: Compute the gradient scaling matrix ---- # + + gradient_scaling_matrix_on_mask = compute_gradient_scaling_matrix_fisher( + fisher=new_fisher_on_mask, + backend=self.joblib_backend, + num_jobs=self.num_jobs, + joblib_verbosity=self.joblib_verbosity, + batch_size=self.irls_batch_size, + ) + + # ---- Step 4: Compute the ascent direction and the newton decrement ---- # + + ( + ascent_direction_on_mask, + newton_decrement_on_mask, + ) = compute_ascent_direction_decrement( + gradient_scaling_matrix=gradient_scaling_matrix_on_mask, + gradient=new_gradient_on_mask, + beta=new_beta_on_mask, + max_beta=self.max_beta, + ) + + round_number_PQN += 1 + + return { + "beta": beta, + "PQN_mask": PQN_mask, + "PQN_diverged_mask": PQN_diverged_mask, + "ascent_direction_on_mask": ascent_direction_on_mask, + "newton_decrement_on_mask": newton_decrement_on_mask, + "round_number_PQN": round_number_PQN, + "global_reg_nll": reg_nll, + "irls_diverged_mask": shared_states[0]["irls_diverged_mask"], + } diff --git a/fedpydeseq2/core/fed_algorithms/fed_PQN/utils.py b/fedpydeseq2/core/fed_algorithms/fed_PQN/utils.py new file mode 100644 index 0000000..48aaeaa --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/fed_PQN/utils.py @@ -0,0 +1,257 @@ +"""Utility functions for the proximal Newton optimization. + +This optimization is used in the catching of the IRLS algorithm. +""" + + +import numpy as np +from joblib import Parallel +from joblib import delayed +from joblib import parallel_backend + +from fedpydeseq2.core.utils.negative_binomial import mu_grid_nb_nll + + +def make_fisher_gradient_nll_step_sizes_batch( + design_matrix: np.ndarray, + size_factors: np.ndarray, + beta: np.ndarray, + dispersions: np.ndarray, + counts: np.ndarray, + ascent_direction: np.ndarray | None, + step_sizes: np.ndarray | None, + beta_min: float | None, + beta_max: float | None, + min_mu: float = 0.5, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Make local gradient, fisher matrix, and nll for multiple steps. + + Parameters + ---------- + design_matrix : ndarray + The design matrix, of shape (n_obs, n_params). + size_factors : ndarray + The size factors, of shape (n_obs). + beta : ndarray + The log fold change matrix, of shape (batch_size, n_params). + dispersions : ndarray + The dispersions, of shape (batch_size). + counts : ndarray + The counts, of shape (n_obs,batch_size). + ascent_direction : np.ndarray + The ascent direction, of shape (batch_size, n_params). + step_sizes: np.ndarray + A list of step sizes to evaluate, of size (n_steps, ). 
+ beta_min: float + The minimum value tolerated for beta. + beta_max: float + The maximum value tolerated for beta. + min_mu : float + Lower bound on estimated means, to ensure numerical stability. + + Returns + ------- + H : Optional[ndarray] + The Fisher information matrix, of shape + (n_steps, batch_size, n_params, n_params). + gradient : ndarray + The gradient, of shape (n_steps, batch_size, n_params). + nll : ndarray + The nll evaluations on all steps, of size (n_steps, batch_size). + """ + # If no ascent direction is provided, we do not need to compute the grid + # of beta values, but only the current beta value, where we unsqueeze the + # first dimension to make it compatible with the rest of the code + # This is the case when we are at the first iteration of the optimization + if ascent_direction is None and step_sizes is None: + beta_grid = np.clip( + beta[None, :, :], + beta_min, + beta_max, + ) # of shape (n_steps, batch_size, n_params) + + # In this case, we compute the grid of beta values, by moving in the direction + # of the ascent direction, by the step sizes + else: + assert isinstance(step_sizes, np.ndarray) and isinstance( + ascent_direction, np.ndarray + ) + beta_grid = np.clip( + beta[None, :, :] - step_sizes[:, None, None] * ascent_direction[None, :, :], + beta_min, + beta_max, + ) # of shape (n_steps, batch_size, n_params) + + mu_grid = size_factors[None, None, :] * np.exp( + (design_matrix[None, None, :, :] @ beta_grid[:, :, :, None]).squeeze(axis=3) + ) # of shape (n_steps, batch_size, n_obs) + mu_grid = np.maximum( + mu_grid, + min_mu, + ) + + # --- Step 1: Compute the gradient ----# + + gradient_term_1 = -(design_matrix.T @ counts).T[ + None, :, : + ] # shape (1, batch_size, n_params) + gradient_term_2 = ( + design_matrix.T[None, None, :, :] + @ ( + (1 / dispersions[None, :, None] + counts.T[None, :, :]) + * mu_grid + / (1 / dispersions[None, :, None] + mu_grid) # n_steps, batch_size, n_obs + )[:, :, :, None] + ).squeeze( + 3 + ) # Shape n_steps, batch_size, n_params + gradient = gradient_term_1 + gradient_term_2 + + # ---- Step 2: Compute the Fisher matrix ----# + + W = mu_grid / (1.0 + mu_grid * dispersions[None, :, None]) + expanded_design = design_matrix[ + None, None, :, : + ] # of shape (1, 1, n_obs, n_params) + assert W is not None + H = (expanded_design * W[:, :, :, None]).transpose(0, 1, 3, 2) @ expanded_design + # H of size (n_steps, batch_size, n_params, n_params) + + # Get the mu_grid + nll = mu_grid_nb_nll(counts, mu_grid, dispersions) + + return H, gradient, nll + + +def compute_gradient_scaling_matrix_fisher( + fisher: np.ndarray, + backend: str, + num_jobs: int, + joblib_verbosity: int, + batch_size: int, +): + """Compute the gradient scaling matrix using the Fisher information. + + In this case, we simply invert the provided Fisher matrix to get the gradient + scaling matrix. 
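+
+    Added note: the inversion is performed in batches of ``batch_size`` genes
+    through joblib; the result is equivalent to ``np.linalg.inv(fisher)``
+    applied to the whole stack of (n_params, n_params) matrices at once.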
+ + Parameters + ---------- + fisher : ndarray + The Fisher matrix, of shape (n_genes, n_params, n_params) + backend : str + The backend to use for parallelization + num_jobs : int + The number of cpus to use + joblib_verbosity : int + The verbosity level of joblib + batch_size : int + The batch size to use for the computation + + Returns + ------- + ndarray + The gradient scaling matrix, of shape (n_genes, n_params, n_params) + """ + with parallel_backend(backend): + res = Parallel(n_jobs=num_jobs, verbose=joblib_verbosity)( + delayed(np.linalg.inv)( + fisher[i : i + batch_size], + ) + for i in range(0, len(fisher), batch_size) + ) + if len(res) > 0: + gradient_scaling_matrix = np.concatenate(res) + else: + gradient_scaling_matrix = np.zeros_like(fisher) + + return gradient_scaling_matrix + + +def compute_ascent_direction_decrement( + gradient_scaling_matrix: np.ndarray, + gradient: np.ndarray, + beta: np.ndarray, + max_beta: float, +): + """Compute the ascent direction and decrement. + + We do this from the gradient scaling matrix, the gradient, + the beta and the max beta, which embodies the box constraints. + + Please look at this paper for the precise references to the equations: + https://www.cs.utexas.edu/~inderjit/public_papers/pqnj_sisc10.pdf + + By ascent direction, we mean that the direction we compute is positively + correlated with the gradient. As our aim is to minimize the function, + we want to move in the opposite direction of the ascent direction, but + it is simpler to compute the ascent direction to avoid sign errors. + + Parameters + ---------- + gradient_scaling_matrix : np.ndarray + The gradient scaling matrix, of shape (n_genes, n_params, n_params). + gradient : np.ndarray + The gradient per gene, of shape (n_genes, n_params). + beta : np.ndarray + Beta on those genes, of shape (n_genes, n_params). + max_beta : float + The max absolute value for beta. + + Returns + ------- + ascent_direction : np.ndarray + The new ascent direction, of shape (n_genes, n_params). 
+ newton_decrement : np.ndarray + The newton decrement associated to these ascent directions + of shape (n_genes, ) + + """ + # ---- Step 1: compute first index set ---- # + # See https://www.cs.utexas.edu/~inderjit/public_papers/pqnj_sisc10.pdf + # equation 2.2 + + lower_binding = (beta < -max_beta + 1e-14) & (gradient > 0) + upper_binding = (beta > max_beta - 1e-14) & (gradient < 0) + first_index_mask = lower_binding | upper_binding # of shape (n_genes, n_params) + + # Set to zero the gradient scaling matrix on the first index + + n_params = beta.shape[1] + + gradient_scaling_matrix[ + np.repeat(first_index_mask[:, :, None], repeats=n_params, axis=2) + ] = 0 + gradient_scaling_matrix[ + np.repeat(first_index_mask[:, None, :], repeats=n_params, axis=1) + ] = 0 + + ascent_direction = (gradient_scaling_matrix @ gradient[:, :, None]).squeeze( + axis=2 + ) # of shape (n_genes, n_params) + + # ---- Step 2: Compute the second index set ---- # + # See https://www.cs.utexas.edu/~inderjit/public_papers/pqnj_sisc10.pdf + # equation 2.3 + + lower_binding = (beta < -max_beta + 1e-14) & (ascent_direction > 0) + upper_binding = (beta > max_beta - 1e-14) & (ascent_direction < 0) + second_index_mask = lower_binding | upper_binding + + # Set to zero the gradient scaling matrix on the second index + + gradient_scaling_matrix[ + np.repeat(second_index_mask[:, :, None], repeats=n_params, axis=2) + ] = 0 + gradient_scaling_matrix[ + np.repeat(second_index_mask[:, None, :], repeats=n_params, axis=1) + ] = 0 + + # ---- Step 3: Compute the ascent direction and Newton decrement ---- # + + ascent_direction = gradient_scaling_matrix @ gradient[:, :, None] + newton_decrement = (gradient[:, None, :] @ ascent_direction).squeeze(axis=(1, 2)) + + ascent_direction = ascent_direction.squeeze(axis=2) + + return ascent_direction, newton_decrement diff --git a/fedpydeseq2/core/fed_algorithms/fed_irls/__init__.py b/fedpydeseq2/core/fed_algorithms/fed_irls/__init__.py new file mode 100644 index 0000000..48bcd02 --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/fed_irls/__init__.py @@ -0,0 +1,3 @@ +"""Module which contains the Mixin in charge of performing FedIRLS.""" + +from fedpydeseq2.core.fed_algorithms.fed_irls.fed_irls import FedIRLS diff --git a/fedpydeseq2/core/fed_algorithms/fed_irls/fed_irls.py b/fedpydeseq2/core/fed_algorithms/fed_irls/fed_irls.py new file mode 100644 index 0000000..ea7882e --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/fed_irls/fed_irls.py @@ -0,0 +1,197 @@ +"""Module containing the ComputeLFC method.""" + +from substrafl.nodes import AggregationNode + +from fedpydeseq2.core.fed_algorithms.fed_irls.substeps import AggMakeIRLSUpdate +from fedpydeseq2.core.fed_algorithms.fed_irls.substeps import LocMakeIRLSSummands +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step + + +class FedIRLS( + LocMakeIRLSSummands, + AggMakeIRLSUpdate, +): + r"""Mixin class to implement the LFC computation algorithm. + + The goal of this class is to implement the IRLS algorithm specifically applied + to the negative binomial distribution, with fixed dispersion parameter (only + the mean parameter, expressed as the exponential of the log fold changes times + the design matrix, is estimated). This algorithm is caught with another method on + the genes on which it fails. + + To the best of our knowledge, there is no explicit implementation of IRLS for the + negative binomial in a federated setting. 
However, the steps of IRLS are akin + to the ones of a Newton-Raphson algorithm, with the difference that the Hessian + matrix is replaced by the Fisher information matrix. + + Let us recall the steps of the IRLS algorithm for one gene (this method then + implements these iterations for all genes in parallell). + We want to estimate the log fold changes :math:`\beta` from the counts :math:`y` + and the design matrix :math:`X`. The negative binomial likelihood is given by: + + .. math:: + \mathcal{L}(\beta) = \sum_{i=1}^n \left( y_i \log(\mu_i) - + (y_i + \alpha^{-1}) \log(\mu_i + \alpha^{-1}) \right) + \text{const}(y, \alpha) + + where :math:`\mu_i = \gamma_i\exp(X_i \cdot \beta)` and :math:`\alpha` is + the dispersion parameter. + + Given an iterate :math:`\beta_k`, the IRLS algorithm computes the next iterate + :math:`\beta_{k+1}` as follows. + + First, we compute the mean parameter :math:`\mu_k` from the current iterate, using + the formula of the log fold changes: + + .. math:: + (\mu_{k})_i = \gamma_i \exp(X_i \cdot \beta_k) + + In practice, we trim the values of :math:`\mu_k` to a minimum value to ensure + numerical stability. + + Then, we compute the weight matrix :math:`W_k` from the current iterate + :math:`\beta_k`, which is a diagonal matrix with diagonal elements: + + .. math:: + (W_k)_{ii} = \frac{\mu_{k,i}}{1 + \mu_{k,i} \alpha} + + where :math:`\alpha` is the dispersion parameter. + This weight matrix is used to compute both the estimated variance (or hat matrix) + and the feature vector :math:`z_k`: + + .. math:: + z_k = \log\left(\frac{\mu_k}{\gamma}\right) + \frac{y - \mu_k}{\mu_k} + + The estimated variance is given by: + + .. math:: + H_k = X^T W_k X + + The update step is then given by: + + .. math:: + \beta_{k+1} = (H_k)^{-1} X^T W_k z_k + + This is akin to the Newton-Raphson algorithm, with the + Hessian matrix replaced by the Fisher information, and the gradient replaced by the + feature vector. + + Methods + ------- + run_fed_irls + Run the IRLS algorithm. + + """ + + def run_fed_irls( + self, + train_data_nodes: list, + aggregation_node: AggregationNode, + local_states: dict, + input_shared_state: dict, + round_idx: int, + clean_models: bool = True, + refit_mode: bool = False, + ): + """Run the IRLS algorithm. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + input_shared_state: dict + Shared state with the following keys: + - beta: ndarray + The current beta, of shape (n_non_zero_genes, n_params). + - irls_diverged_mask: ndarray + A boolean mask indicating if fed avg should be used for a given gene + (shape: (n_non_zero_genes,)). + - irls_mask: ndarray + A boolean mask indicating if IRLS should be used for a given gene + (shape: (n_non_zero_genes,)). + - global_nll: ndarray + The global_nll of the current beta from the previous beta, of shape\ + (n_non_zero_genes,). + - round_number_irls: int + The current round number of the IRLS algorithm. + + round_idx: int + The current round. + + clean_models: bool + If True, the models are cleaned. + + refit_mode: bool + Whether to run on `refit_adata`s instead of `local_adata`s. + (default: False). + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + global_irls_summands_nlls_shared_state: dict + Shared states containing the final IRLS results. + It contains nothing for now. 
+ - beta: ndarray + The current beta, of shape (n_non_zero_genes, n_params). + - irls_diverged_mask: ndarray + A boolean mask indicating if fed avg should be used for a given gene + (shape: (n_non_zero_genes,)). + - irls_mask: ndarray + A boolean mask indicating if IRLS should be used for a given gene + (shape: (n_non_zero_genes,)). + - global_nll: ndarray + The global_nll of the current beta from the previous beta, of shape\ + (n_non_zero_genes,). + - round_number_irls: int + The current round number of the IRLS algorithm. + + round_idx: int + The updated round index. + + """ + #### ---- Main training loop ---- ##### + + global_irls_summands_nlls_shared_state = input_shared_state + + for _ in range(self.irls_num_iter + 1): + # ---- Compute local IRLS summands and nlls ---- # + + ( + local_states, + local_irls_summands_nlls_shared_states, + round_idx, + ) = local_step( + local_method=self.make_local_irls_summands_and_nlls, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=global_irls_summands_nlls_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Compute local IRLS summands and nlls.", + clean_models=clean_models, + method_params={"refit_mode": refit_mode}, + ) + + # ---- Compute global IRLS update and nlls ---- # + + global_irls_summands_nlls_shared_state, round_idx = aggregation_step( + aggregation_method=self.make_global_irls_update, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=local_irls_summands_nlls_shared_states, + round_idx=round_idx, + description="Update the log fold changes and nlls in IRLS.", + clean_models=clean_models, + ) + + return local_states, global_irls_summands_nlls_shared_state, round_idx diff --git a/fedpydeseq2/core/fed_algorithms/fed_irls/substeps.py b/fedpydeseq2/core/fed_algorithms/fed_irls/substeps.py new file mode 100644 index 0000000..81041b0 --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/fed_irls/substeps.py @@ -0,0 +1,381 @@ +"""Module to implement the substeps for the fitting of log fold changes. + +This module contains all these substeps as mixin classes. +""" + +from typing import Any + +import numpy as np +from anndata import AnnData +from joblib import Parallel +from joblib import delayed +from joblib import parallel_backend +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.fed_algorithms.fed_irls.utils import ( + make_irls_update_summands_and_nll_batch, +) +from fedpydeseq2.core.utils.compute_lfc_utils import get_lfc_utils_from_gene_mask_adata +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocMakeIRLSSummands: + """Mixin to make the summands for the IRLS algorithm. + + Attributes + ---------- + local_adata : AnnData + The local AnnData object. + num_jobs : int + The number of cpus to use. + joblib_verbosity : int + The verbosity of the joblib backend. + joblib_backend : str + The backend to use for the joblib parallelization. + irls_batch_size : int + The batch size to use for the IRLS algorithm. + min_mu : float + The minimum value for the mu parameter. + irls_num_iter : int + The number of iterations for the IRLS algorithm. + + Methods + ------- + make_local_irls_summands_and_nlls + A remote_data method. Makes the summands for the IRLS algorithm. 
+ It also passes on the necessary global quantities. + + """ + + local_adata: AnnData + refit_adata: AnnData + num_jobs: int + joblib_verbosity: int + joblib_backend: str + irls_batch_size: int + min_mu: float + irls_num_iter: int + + @remote_data + @log_remote_data + @reconstruct_adatas + def make_local_irls_summands_and_nlls( + self, + data_from_opener: AnnData, + shared_state: dict[str, Any], + refit_mode: bool = False, + ): + """Make the summands for the IRLS algorithm. + + This functions does two main operations: + + 1) It computes the summands for the beta update. + 2) It computes the local quantities to compute the global_nll + of the current beta + + + Parameters + ---------- + data_from_opener : AnnData + Not used. + + shared_state : dict + A dictionary containing the following + keys: + - beta: ndarray + The current beta, of shape (n_non_zero_genes, n_params). + - irls_diverged_mask: ndarray + A boolean mask indicating if fed avg should be used for a given gene + (shape: (n_non_zero_genes,)). + - irls_mask: ndarray + A boolean mask indicating if IRLS should be used for a given gene + (shape: (n_non_zero_genes,)). + - global_nll: ndarray + The global_nll of the current beta from the previous beta, of shape\ + (n_non_zero_genes,). + - round_number_irls: int + The current round number of the IRLS algorithm. + + refit_mode : bool + Whether to run on `refit_adata`s instead of `local_adata`s. + (default: False). + + Returns + ------- + dict + The state to share to the server. + It contains the following fields: + - beta: ndarray + The current beta, of shape (n_non_zero_genes, n_params). + - local_nll: ndarray + The local nll of the current beta, of shape (n_irls_genes,). + - local_hat_matrix: ndarray + The local hat matrix, of shape (n_irls_genes, n_params, n_params). + n_irsl_genes is the number of genes that are still active (non zero + gene names on the irls_mask). + - local_features: ndarray + The local features, of shape (n_irls_genes, n_params). + - irls_diverged_mask: ndarray + A boolean mask indicating if fed avg should be used for a given gene + (shape: (n_non_zero_genes,)). + - irls_mask: ndarray + A boolean mask indicating if IRLS should be used for a given gene + (shape: (n_non_zero_genes,)). + - global_nll: ndarray + The global_nll of the current beta of shape + (n_non_zero_genes,). + This parameter is simply passed to the next shared state + - round_number_irls: int + The current round number of the IRLS algorithm. + This round number is not updated here. 
+ + """ + if refit_mode: + adata = self.refit_adata + else: + adata = self.local_adata + + # Put all elements in the shared state in readable variables + beta = shared_state["beta"] + irls_mask = shared_state["irls_mask"] + irls_diverged_mask = shared_state["irls_diverged_mask"] + global_nll = shared_state["global_nll"] + round_number_irls = shared_state["round_number_irls"] + + # Get the quantitie stored in the adata + disp_param_name = adata.uns["_irls_disp_param_name"] + + # If this is the first round, save the beta init in a field of the local adata + if round_number_irls == 0: + adata.uns["_irls_beta_init"] = beta.copy() + + ( + irls_gene_names, + design_matrix, + size_factors, + counts, + dispersions, + beta_genes, + ) = get_lfc_utils_from_gene_mask_adata( + adata, irls_mask, beta=beta, disp_param_name=disp_param_name + ) + + # ---- Compute the summands for the beta update and the local nll ---- # + + with parallel_backend(self.joblib_backend): + res = Parallel(n_jobs=self.num_jobs, verbose=self.joblib_verbosity)( + delayed(make_irls_update_summands_and_nll_batch)( + design_matrix, + size_factors, + beta_genes[i : i + self.irls_batch_size], + dispersions[i : i + self.irls_batch_size], + counts[:, i : i + self.irls_batch_size], + self.min_mu, + ) + for i in range(0, len(beta_genes), self.irls_batch_size) + ) + + if len(res) == 0: + H = np.zeros((0, beta.shape[1], beta.shape[1])) + y = np.zeros((0, beta.shape[1])) + local_nll = np.zeros(0) + else: + H = np.concatenate([r[0] for r in res]) + y = np.concatenate([r[1] for r in res]) + local_nll = np.concatenate([r[2] for r in res]) + + # Create the shared state + return { + "beta": beta, + "local_nll": local_nll, + "local_hat_matrix": H, + "local_features": y, + "irls_gene_names": irls_gene_names, + "irls_diverged_mask": irls_diverged_mask, + "irls_mask": irls_mask, + "global_nll": global_nll, + "round_number_irls": round_number_irls, + } + + +class AggMakeIRLSUpdate: + """Mixin class to aggregate IRLS summands. + + Please refer to the method make_local_irls_summands_and_nlls for more. + + Attributes + ---------- + num_jobs : int + The number of cpus to use. + joblib_verbosity : int + The verbosity of the joblib backend. + joblib_backend : str + The backend to use for the joblib parallelization. + irls_batch_size : int + The batch size to use for the IRLS algorithm. + max_beta : float + The maximum value for the beta parameter. + beta_tol : float + The tolerance for the beta parameter. + irls_num_iter : int + The number of iterations for the IRLS algorithm. + + Methods + ------- + make_global_irls_update + A remote method. Aggregates the local quantities to create + the global IRLS update. It also updates the masks indicating which genes + have diverged or converged according to the deviance. + + """ + + num_jobs: int + joblib_verbosity: int + joblib_backend: str + irls_batch_size: int + max_beta: float + beta_tol: float + irls_num_iter: int + + @remote + @log_remote + def make_global_irls_update(self, shared_states: list[dict]) -> dict[str, Any]: + """Make the summands for the IRLS algorithm. + + The role of this function is twofold. + + 1) It computes the global_nll and updates the masks according to the deviance, + for the beta values that have been computed in the previous round. + + 2) It aggregates the local hat matrix and features to solve the linear system + and get the new beta values. 
+ + Parameters + ---------- + shared_states: list[dict] + A list of dictionaries containing the following + keys: + - local_hat_matrix: ndarray + The local hat matrix, of shape (n_irls_genes, n_params, n_params). + n_irsl_genes is the number of genes that are still active (non zero + gene names on the irls_mask). + - local_features: ndarray + The local features, of shape (n_irls_genes, n_params). + - irls_diverged_mask: ndarray + A boolean mask indicating if fed avg should be used for a given gene + (shape: (n_non_zero_genes,)). + - irls_mask: ndarray + A boolean mask indicating if IRLS should be used for a given gene + (shape: (n_non_zero_genes,)). + - global_nll: ndarray + The global_nll of the current beta from the previous beta, of shape\ + (n_non_zero_genes,). + - round_number_irls: int + The current round number of the IRLS algorithm. + + Returns + ------- + dict[str, Any] + A dictionary containing all the necessary info to run IRLS. + It contains the following fields: + - beta: ndarray + The log fold changes, of shape (n_non_zero_genes, n_params). + - irls_diverged_mask: ndarray + A boolean mask indicating if fed avg should be used for a given gene + (shape: (n_non_zero_genes,)). + - irls_mask: ndarray + A boolean mask indicating if IRLS should be used for a given gene + (shape: (n_non_zero_genes,)). + - global_nll: ndarray + The global_nll of the current beta from the previous beta, of shape\ + (n_non_zero_genes,). + - round_number_irls: int + The current round number of the IRLS algorithm. + + + """ + # Load main params from the first state + beta = shared_states[0]["beta"] + irls_mask = shared_states[0]["irls_mask"] + irls_diverged_mask = shared_states[0]["irls_diverged_mask"] + global_nll = shared_states[0]["global_nll"] + round_number_irls = shared_states[0]["round_number_irls"] + + # ---- Step 0: Aggregate the local hat matrix, features and global_nll ---- # + + global_hat_matrix = sum([state["local_hat_matrix"] for state in shared_states]) + global_features = sum([state["local_features"] for state in shared_states]) + global_nll_on_irls_mask = sum([state["local_nll"] for state in shared_states]) + + # ---- Step 1: update global_nll and masks ---- # + + # The first round needs to be handled separately + if round_number_irls == 0: + # In that case, the irls_masks consists in all True values + # We only need set the initial global_nll + global_nll = global_nll_on_irls_mask + + else: + old_global_nll = global_nll.copy() + old_irls_mask = irls_mask.copy() + + global_nll[irls_mask] = global_nll_on_irls_mask + + # Set the new masks with the dev ratio and beta values + deviance_ratio = np.abs(2 * global_nll - 2 * old_global_nll) / ( + np.abs(2 * global_nll) + 0.1 + ) + irls_diverged_mask = irls_diverged_mask | ( + np.abs(beta) > self.max_beta + ).any(axis=1) + + irls_mask = irls_mask & (deviance_ratio > self.beta_tol) + irls_mask = irls_mask & ~irls_diverged_mask + new_mask_in_old_mask = (irls_mask & old_irls_mask)[old_irls_mask] + global_hat_matrix = global_hat_matrix[new_mask_in_old_mask] + global_features = global_features[new_mask_in_old_mask] + + if round_number_irls == self.irls_num_iter: + # In this case, we must prepare the switch to fed prox newton + return { + "beta": beta, + "irls_diverged_mask": irls_diverged_mask, + "irls_mask": irls_mask, + "global_nll": global_nll, + "round_number_irls": round_number_irls, + } + + # ---- Step 2: Solve the system to compute beta ---- # + + ridge_factor = np.diag(np.repeat(1e-6, global_hat_matrix.shape[1])) + with 
parallel_backend(self.joblib_backend): + res = Parallel(n_jobs=self.num_jobs, verbose=self.joblib_verbosity)( + delayed(np.linalg.solve)( + global_hat_matrix[i : i + self.irls_batch_size] + ridge_factor, + global_features[i : i + self.irls_batch_size], + ) + for i in range(0, len(global_hat_matrix), self.irls_batch_size) + ) + if len(res) > 0: + beta_hat = np.concatenate(res) + else: + beta_hat = np.zeros((0, global_hat_matrix.shape[1])) + + # TODO : it would be cleaner to pass an update, which is None at the first + # round. That way we do not update beta in a different step its evaluation. + + # Update the beta + beta[irls_mask] = beta_hat + + round_number_irls = round_number_irls + 1 + + return { + "beta": beta, + "irls_diverged_mask": irls_diverged_mask, + "irls_mask": irls_mask, + "global_nll": global_nll, + "round_number_irls": round_number_irls, + } diff --git a/fedpydeseq2/core/fed_algorithms/fed_irls/utils.py b/fedpydeseq2/core/fed_algorithms/fed_irls/utils.py new file mode 100644 index 0000000..a346b7c --- /dev/null +++ b/fedpydeseq2/core/fed_algorithms/fed_irls/utils.py @@ -0,0 +1,81 @@ +"""Module to implement the utilities of the IRLS algorithm. + +Most of these functions have the _batch suffix, which means that they are +vectorized to work over batches of genes in the parralel_backend file in +the same module. +""" + +import numpy as np + +from fedpydeseq2.core.utils.negative_binomial import grid_nb_nll + + +def make_irls_update_summands_and_nll_batch( + design_matrix: np.ndarray, + size_factors: np.ndarray, + beta: np.ndarray, + dispersions: np.ndarray, + counts: np.ndarray, + min_mu: float, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Make the summands for the IRLS algorithm for a given set of genes. + + Parameters + ---------- + design_matrix : ndarray + The design matrix, of shape (n_obs, n_params). + size_factors : ndarray + The size factors, of shape (n_obs). + beta : ndarray + The log fold change matrix, of shape (batch_size, n_params). + dispersions : ndarray + The dispersions, of shape (batch_size). + counts : ndarray + The counts, of shape (n_obs,batch_size). + min_mu : float + Lower bound on estimated means, to ensure numerical stability. + + Returns + ------- + H : ndarray + The H matrix, of shape (batch_size, n_params, n_params). + y : ndarray + The y vector, of shape (batch_size, n_params). + nll : ndarray + The negative binomial negative log-likelihood, of shape (batch_size). 
+ """ + max_limit = np.log(1e100) + design_matrix_time_beta_T = design_matrix @ beta.T + mask_nan = design_matrix_time_beta_T > max_limit + + # In order to avoid overflow and np.inf, we replace all big values in the + # design_matrix_time_beta_T with 0., then we carry the computation normally, and + # we modify the final quantity with their true value for the inputs were + # exp_design_matrix_time_beta_T should have taken values >> 1 + exp_design_matrix_time_beta_T = np.zeros( + design_matrix_time_beta_T.shape, dtype=design_matrix_time_beta_T.dtype + ) + exp_design_matrix_time_beta_T[~mask_nan] = np.exp( + design_matrix_time_beta_T[~mask_nan] + ) + mu = size_factors[:, None] * exp_design_matrix_time_beta_T + + mu = np.maximum(mu, min_mu) + + W = mu / (1.0 + mu * dispersions[None, :]) + + dispersions_broadcast = np.broadcast_to( + dispersions, (mu.shape[0], dispersions.shape[0]) + ) + W[mask_nan] = 1.0 / dispersions_broadcast[mask_nan] + + z = np.log(mu / size_factors[:, None]) + (counts - mu) / mu + z[mask_nan] = design_matrix_time_beta_T[mask_nan] - 1.0 + + H = (design_matrix.T[:, :, None] * W).transpose(2, 0, 1) @ design_matrix[None, :, :] + y = (design_matrix.T @ (W * z)).T + + mu[mask_nan] = np.inf + nll = grid_nb_nll(counts, mu, dispersions, mask_nan) + + return H, y, nll diff --git a/fedpydeseq2/core/utils/__init__.py b/fedpydeseq2/core/utils/__init__.py new file mode 100644 index 0000000..073e4eb --- /dev/null +++ b/fedpydeseq2/core/utils/__init__.py @@ -0,0 +1,14 @@ +from fedpydeseq2.core.utils.aggregation import aggregate_means +from fedpydeseq2.core.utils.design_matrix import build_design_matrix +from fedpydeseq2.core.utils.mle import batch_mle_grad +from fedpydeseq2.core.utils.mle import batch_mle_update +from fedpydeseq2.core.utils.mle import global_grid_cr_loss +from fedpydeseq2.core.utils.mle import local_grid_summands +from fedpydeseq2.core.utils.mle import single_mle_grad +from fedpydeseq2.core.utils.mle import vec_loss +from fedpydeseq2.core.utils.negative_binomial import vec_nb_nll_grad +from fedpydeseq2.core.utils.pipe_steps import aggregation_step +from fedpydeseq2.core.utils.pipe_steps import local_step +from fedpydeseq2.core.utils.stat_utils import build_contrast +from fedpydeseq2.core.utils.stat_utils import build_contrast_vector +from fedpydeseq2.core.utils.stat_utils import wald_test diff --git a/fedpydeseq2/core/utils/aggregation.py b/fedpydeseq2/core/utils/aggregation.py new file mode 100644 index 0000000..51d4c75 --- /dev/null +++ b/fedpydeseq2/core/utils/aggregation.py @@ -0,0 +1,42 @@ +""" +Aggregation functions. + +Copy-pasted from the CancerLINQ repo. +""" + +from typing import Any + +import numpy as np + + +# pylint: disable=deprecated-typing-alias +def aggregate_means( + local_means: list[Any], n_local_samples: list[int], filter_nan: bool = False +): + """Aggregate local means. + + Aggregate the local means into a global mean by using the local number of samples. + + Parameters + ---------- + local_means : list[Any] + list of local means. Could be array, float, Series. + n_local_samples : list[int] + list of number of samples used for each local mean. + filter_nan : bool, optional + Filter NaN values in the local means, by default False. + + Returns + ------- + Any + Aggregated mean. 
Same type of the local means + """ + tot_samples = 0 + tot_mean = np.zeros_like(local_means[0]) + for mean, n_sample in zip(local_means, n_local_samples, strict=False): + if filter_nan: + mean = np.nan_to_num(mean, nan=0, copy=False) + tot_mean += mean * n_sample + tot_samples += n_sample + + return tot_mean / tot_samples diff --git a/fedpydeseq2/core/utils/compute_lfc_utils.py b/fedpydeseq2/core/utils/compute_lfc_utils.py new file mode 100644 index 0000000..8c05eb1 --- /dev/null +++ b/fedpydeseq2/core/utils/compute_lfc_utils.py @@ -0,0 +1,78 @@ +import anndata as ad +import numpy as np + + +def get_lfc_utils_from_gene_mask_adata( + adata: ad.AnnData, + gene_mask: np.ndarray | None, + disp_param_name: str, + beta: np.ndarray | None = None, + lfc_param_name: str | None = None, +) -> tuple[list[str], np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Get the necessary data for LFC computations from the local adata and genes. + + Parameters + ---------- + adata : ad.AnnData + The local AnnData object. + + gene_mask : np.ndarray or None + The mask of genes to use for the IRLS algorithm. + This mask identifies the genes in the non_zero_gene_names. + If None, all non zero genes are used. + + disp_param_name : str + The name of the dispersion parameter in the adata.varm. + + beta : Optional[np.ndarray] + The log fold change values, of shape (n_non_zero_genes,). + + lfc_param_name: Optional[str] + The name of the lfc parameter in the adata.varm. + Is incompatible with beta. + + Returns + ------- + gene_names : list[str] + The names of the genes to use for the IRLS algorithm. + design_matrix : np.ndarray + The design matrix. + size_factors : np.ndarray + The size factors. + counts : np.ndarray + The count matrix from the local adata. + dispersions : np.ndarray + The dispersions from the local adata. + beta_on_mask : np.ndarray + The log fold change values on the mask. + """ + # Check that one of beta or lfc_param_name is not None + assert (beta is not None) ^ ( + lfc_param_name is not None + ), "One of beta or lfc_param_name must be not None" + + # Get non zero genes + non_zero_genes_names = adata.var_names[adata.varm["non_zero"]] + + # Get the irls genes + if gene_mask is None: + gene_names = non_zero_genes_names + else: + gene_names = non_zero_genes_names[gene_mask] + + # Get beta + if lfc_param_name is not None: + beta_on_mask = adata[:, gene_names].varm[lfc_param_name].to_numpy() + elif gene_mask is not None: + assert beta is not None # for mypy + beta_on_mask = beta[gene_mask] + else: + assert beta is not None # for mypy + beta_on_mask = beta.copy() + + design_matrix = adata.obsm["design_matrix"].values + size_factors = adata.obsm["size_factors"] + counts = adata[:, gene_names].X + dispersions = adata[:, gene_names].varm[disp_param_name] + + return gene_names, design_matrix, size_factors, counts, dispersions, beta_on_mask diff --git a/fedpydeseq2/core/utils/design_matrix.py b/fedpydeseq2/core/utils/design_matrix.py new file mode 100644 index 0000000..8546fff --- /dev/null +++ b/fedpydeseq2/core/utils/design_matrix.py @@ -0,0 +1,140 @@ +import warnings + +import numpy as np +import pandas as pd +import pandas.api.types as ptypes + + +def build_design_matrix( + metadata: pd.DataFrame, + design_factors: str | list[str] = "stage", + levels: dict[str, list[str]] | None = None, + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = None, +) -> pd.DataFrame: + """Build design_matrix matrix for DEA. 
+ + Unless specified, the reference factor is chosen alphabetically. + Copied from PyDESeq2, with some modifications specific to fedomics to ensure that + all centers have the same columns + + Parameters + ---------- + metadata : pandas.DataFrame + DataFrame containing metadata information. + Must be indexed by sample barcodes. + + design_factors : str or list + Name of the columns of metadata to be used as design_matrix variables. + (default: ``"condition"``). + + levels : dict or None + An optional dictionary of lists of strings specifying the levels of each factor + in the global design, e.g. ``{"condition": ["A", "B"]}``. (default: ``None``). + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors, that should + also be in ``design_factors``. Any factor in ``design_factors`` but not in + ``continuous_factors`` will be considered categorical (default: ``None``). + + Returns + ------- + pandas.DataFrame + A DataFrame with experiment design information (to split cohorts). + Indexed by sample barcodes. + """ + if isinstance( + design_factors, str + ): # if there is a single factor, convert to singleton list + design_factors = [design_factors] + + # Check that factors in the design don't contain underscores. If so, convert + # them to hyphens + if np.any(["_" in factor for factor in design_factors]): + warnings.warn( + """Same factor names in the design contain underscores ('_'). They will + be converted to hyphens ('-').""", + UserWarning, + stacklevel=2, + ) + design_factors = [factor.replace("_", "-") for factor in design_factors] + + # Check that level factors in the design don't contain underscores. If so, convert + # them to hyphens + warning_issued = False + for factor in design_factors: + if ptypes.is_numeric_dtype(metadata[factor]): + continue + if np.any(["_" in value for value in metadata[factor]]): + if not warning_issued: + warnings.warn( + """Some factor levels in the design contain underscores ('_'). + They will be converted to hyphens ('-').""", + UserWarning, + stacklevel=2, + ) + warning_issued = True + metadata[factor] = metadata[factor].apply(lambda x: x.replace("_", "-")) + + if continuous_factors is not None: + for factor in continuous_factors: + if factor not in design_factors: + raise ValueError( + f"Continuous factor '{factor}' not in design factors: " + f"{design_factors}." + ) + categorical_factors = [ + factor for factor in design_factors if factor not in continuous_factors + ] + else: + categorical_factors = design_factors + + if levels is None: + levels = {factor: np.unique(metadata[factor]) for factor in categorical_factors} + + # Check that there is at least one categorical factor + if len(categorical_factors) > 0: + design_matrix = pd.get_dummies(metadata[categorical_factors], drop_first=False) + # Check if there missing levels. If so, add them and set to 0. + for factor in categorical_factors: + for level in levels[factor]: + if f"{factor}_{level}" not in design_matrix.columns: + design_matrix[f"{factor}_{level}"] = 0 + + # Pick the first level as reference. Then, drop the column. 
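+        # Added example: for a categorical factor "stage" with levels ["I", "II"]
+        # and reference level "I", the one-hot column "stage_II" is kept and
+        # renamed "stage_II_vs_I", while the reference column "stage_I" is dropped.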
+ for factor in categorical_factors: + if ref_levels is not None and factor in ref_levels: + ref = ref_levels[factor] + else: + ref = levels[factor][0] + + ref_level_name = f"{factor}_{ref}" + design_matrix.drop(ref_level_name, axis="columns", inplace=True) + + # Add reference level as column name suffix + design_matrix.columns = [ + f"{col}_vs_{ref}" if col.startswith(factor) else col + for col in design_matrix.columns + ] + else: + # There is no categorical factor in the design + design_matrix = pd.DataFrame(index=metadata.index) + + # Add the intercept column + design_matrix.insert(0, "intercept", 1) + + # Convert categorical factors one-hot encodings to int + design_matrix = design_matrix.astype("int") + + # Add continuous factors + if continuous_factors is not None: + for factor in continuous_factors: + # This factor should be numeric + design_matrix[factor] = pd.to_numeric(metadata[factor]) + return design_matrix diff --git a/fedpydeseq2/core/utils/layers/__init__.py b/fedpydeseq2/core/utils/layers/__init__.py new file mode 100644 index 0000000..99a1b25 --- /dev/null +++ b/fedpydeseq2/core/utils/layers/__init__.py @@ -0,0 +1,6 @@ +from fedpydeseq2.core.utils.layers.cooks_layer import prepare_cooks_local +from fedpydeseq2.core.utils.layers.cooks_layer import prepare_cooks_agg +from fedpydeseq2.core.utils.layers.reconstruct_adatas_decorator import ( + reconstruct_adatas, +) +from fedpydeseq2.core.utils.layers.utils import set_mu_layer diff --git a/fedpydeseq2/core/utils/layers/build_layers/__init__.py b/fedpydeseq2/core/utils/layers/build_layers/__init__.py new file mode 100644 index 0000000..79dd22e --- /dev/null +++ b/fedpydeseq2/core/utils/layers/build_layers/__init__.py @@ -0,0 +1,31 @@ +"""Module to construct the layers.""" + +from fedpydeseq2.core.utils.layers.build_layers.normed_counts import ( + can_get_normed_counts, + set_normed_counts, +) +from fedpydeseq2.core.utils.layers.build_layers.y_hat import can_get_y_hat, set_y_hat +from fedpydeseq2.core.utils.layers.build_layers.fit_lin_mu_hat import ( + can_get_fit_lin_mu_hat, + set_fit_lin_mu_hat, +) +from fedpydeseq2.core.utils.layers.build_layers.mu_layer import ( + can_set_mu_layer, + set_mu_layer, +) +from fedpydeseq2.core.utils.layers.build_layers.mu_hat import ( + can_get_mu_hat, + set_mu_hat_layer, +) +from fedpydeseq2.core.utils.layers.build_layers.sqerror import ( + can_get_sqerror_layer, + set_sqerror_layer, +) +from fedpydeseq2.core.utils.layers.build_layers.hat_diagonals import ( + can_set_hat_diagonals_layer, + set_hat_diagonals_layer, +) +from fedpydeseq2.core.utils.layers.build_layers.cooks import ( + can_set_cooks_layer, + set_cooks_layer, +) diff --git a/fedpydeseq2/core/utils/layers/build_layers/cooks.py b/fedpydeseq2/core/utils/layers/build_layers/cooks.py new file mode 100644 index 0000000..7d01b9c --- /dev/null +++ b/fedpydeseq2/core/utils/layers/build_layers/cooks.py @@ -0,0 +1,137 @@ +"""Module to set the cooks layer.""" + + +import anndata as ad +import numpy as np + +from fedpydeseq2.core.utils.layers.build_layers.hat_diagonals import ( + can_set_hat_diagonals_layer, +) +from fedpydeseq2.core.utils.layers.build_layers.hat_diagonals import ( + set_hat_diagonals_layer, +) +from fedpydeseq2.core.utils.layers.build_layers.mu_layer import can_set_mu_layer +from fedpydeseq2.core.utils.layers.build_layers.mu_layer import set_mu_layer + + +def can_set_cooks_layer( + adata: ad.AnnData, shared_state: dict | None, raise_error: bool = False +) -> bool: + """Check if the Cook's distance can be set. 
+ + Parameters + ---------- + adata : ad.AnnData + The local adata. + + shared_state : Optional[dict] + The shared state containing the Cook's dispersion values. + + raise_error : bool + Whether to raise an error if the Cook's distance cannot be set. + + Returns + ------- + bool: + Whether the Cook's distance can be set. + + Raises + ------ + ValueError: + If the Cook's distance cannot be set and raise_error is True. + + """ + if "cooks" in adata.layers.keys(): + return True + if shared_state is None: + if raise_error: + raise ValueError( + "To set cooks layer, there should be " "an input shared state" + ) + else: + return False + has_non_zero = "non_zero" in adata.varm.keys() + try: + has_hat_diagonals = can_set_hat_diagonals_layer( + adata, shared_state, raise_error + ) + except ValueError as hat_diagonals_error: + raise ValueError( + "The Cook's distance cannot be set because the hat diagonals cannot be set." + ) from hat_diagonals_error + try: + has_mu_LFC = can_set_mu_layer( + local_adata=adata, + lfc_param_name="LFC", + mu_param_name="_mu_LFC", + ) + except ValueError as mu_LFC_error: + raise ValueError( + "The Cook's distance cannot be set because the mu_LFC layer cannot be set." + ) from mu_LFC_error + has_X = adata.X is not None + has_cooks_dispersions = "cooks_dispersions" in shared_state.keys() + has_all = ( + has_non_zero + and has_hat_diagonals + and has_mu_LFC + and has_X + and has_cooks_dispersions + ) + if not has_all and raise_error: + raise ValueError( + "The Cook's distance cannot be set because " + "the following conditions are not met:" + f"\n- has_non_zero: {has_non_zero}" + f"\n- has_hat_diagonals: {has_hat_diagonals}" + f"\n- has_mu_LFC: {has_mu_LFC}" + f"\n- has_X: {has_X}" + f"\n- has_cooks_dispersions: {has_cooks_dispersions}" + ) + return has_all + + +def set_cooks_layer( + adata: ad.AnnData, + shared_state: dict | None, +): + """Compute the Cook's distance from the shared state. + + This function computes the Cook's distance from the shared state and stores it + in the "cooks" layer of the local adata. + + Parameters + ---------- + adata : ad.AnnData + The local adata. + + shared_state : dict + The shared state containing the Cook's dispersion values. 
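+
+    Added note: per sample and gene, the distance computed below is the squared
+    Pearson residual (using the dispersions shared in ``cooks_dispersions``)
+    divided by the number of design parameters, and scaled by
+    ``h / (1 - h) ** 2`` where ``h`` is the hat-matrix diagonal.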
+ + """ + can_set_cooks_layer(adata, shared_state, raise_error=True) + if "cooks" in adata.layers.keys(): + return + # set all necessary layers + assert isinstance(shared_state, dict) + set_mu_layer(adata, lfc_param_name="LFC", mu_param_name="_mu_LFC") + set_hat_diagonals_layer(adata, shared_state) + num_vars = adata.uns["n_params"] + cooks_dispersions = shared_state["cooks_dispersions"] + V = ( + adata[:, adata.varm["non_zero"]].layers["_mu_LFC"] + + cooks_dispersions[None, adata.varm["non_zero"]] + * adata[:, adata.varm["non_zero"]].layers["_mu_LFC"] ** 2 + ) + squared_pearson_res = ( + adata[:, adata.varm["non_zero"]].X + - adata[:, adata.varm["non_zero"]].layers["_mu_LFC"] + ) ** 2 / V + diag_mul = ( + adata[:, adata.varm["non_zero"]].layers["_hat_diagonals"] + / (1 - adata[:, adata.varm["non_zero"]].layers["_hat_diagonals"]) ** 2 + ) + adata.layers["cooks"] = np.full((adata.n_obs, adata.n_vars), np.NaN) + adata.layers["cooks"][:, adata.varm["non_zero"]] = ( + squared_pearson_res / num_vars * diag_mul + ) diff --git a/fedpydeseq2/core/utils/layers/build_layers/fit_lin_mu_hat.py b/fedpydeseq2/core/utils/layers/build_layers/fit_lin_mu_hat.py new file mode 100644 index 0000000..754f9d8 --- /dev/null +++ b/fedpydeseq2/core/utils/layers/build_layers/fit_lin_mu_hat.py @@ -0,0 +1,77 @@ +"""Module to reconstruct the fit_lin_mu_hat layer.""" + +import anndata as ad +import numpy as np + +from fedpydeseq2.core.utils.layers.build_layers.y_hat import can_get_y_hat +from fedpydeseq2.core.utils.layers.build_layers.y_hat import set_y_hat + + +def can_get_fit_lin_mu_hat(local_adata: ad.AnnData, raise_error: bool = False) -> bool: + """Check if the fit_lin_mu_hat layer can be reconstructed. + + Parameters + ---------- + local_adata : ad.AnnData + The local AnnData object. + + raise_error : bool, optional + If True, raise an error if the fit_lin_mu_hat layer cannot be reconstructed. + + Returns + ------- + bool + True if the fit_lin_mu_hat layer can be reconstructed, False otherwise. + + Raises + ------ + ValueError + If the fit_lin_mu_hat layer cannot be reconstructed and raise_error is True. + + """ + if "_fit_lin_mu_hat" in local_adata.layers.keys(): + return True + try: + y_hat_ok = can_get_y_hat(local_adata, raise_error=raise_error) + except ValueError as y_hat_error: + raise ValueError( + f"Error while checking if y_hat can be reconstructed: {y_hat_error}" + ) from y_hat_error + + has_size_factors = "size_factors" in local_adata.obsm.keys() + has_non_zero = "non_zero" in local_adata.varm.keys() + if not has_size_factors or not has_non_zero: + if raise_error: + raise ValueError( + "Local adata must contain the size_factors obsm " + "and the non_zero varm to compute the fit_lin_mu_hat layer." + " Here are the keys present in the local adata: " + f"obsm : {local_adata.obsm.keys()} and varm : {local_adata.varm.keys()}" + ) + return False + return y_hat_ok + + +def set_fit_lin_mu_hat(local_adata: ad.AnnData, min_mu: float = 0.5): + """ + Calculate the _fit_lin_mu_hat layer using the provided local data. + + Checks are performed to ensure necessary keys are present in the data. + + Parameters + ---------- + local_adata : ad.AnnData + The local anndata object containing necessary keys for computation. + min_mu : float, optional + The minimum value for mu, defaults to 0.5. 
+ + """ + can_get_fit_lin_mu_hat(local_adata, raise_error=True) + if "_fit_lin_mu_hat" in local_adata.layers.keys(): + return + set_y_hat(local_adata) + mu_hat = local_adata.obsm["size_factors"][:, None] * local_adata.layers["_y_hat"] + fit_lin_mu_hat = np.maximum(mu_hat, min_mu) + + fit_lin_mu_hat[:, ~local_adata.varm["non_zero"]] = np.nan + local_adata.layers["_fit_lin_mu_hat"] = fit_lin_mu_hat diff --git a/fedpydeseq2/core/utils/layers/build_layers/hat_diagonals.py b/fedpydeseq2/core/utils/layers/build_layers/hat_diagonals.py new file mode 100644 index 0000000..0595c7e --- /dev/null +++ b/fedpydeseq2/core/utils/layers/build_layers/hat_diagonals.py @@ -0,0 +1,216 @@ +"""Module to set the hat diagonals layer.""" + +import anndata as ad +import numpy as np +from joblib import Parallel +from joblib import delayed +from joblib import parallel_backend + + +def can_set_hat_diagonals_layer( + adata: ad.AnnData, shared_state: dict | None, raise_error: bool = False +) -> bool: + """Check if the hat diagonals layer can be reconstructed. + + Parameters + ---------- + adata : ad.AnnData + The AnnData object. + + shared_state : Optional[dict] + The shared state dictionary. + + raise_error : bool, optional + If True, raise an error if the hat diagonals layer cannot be reconstructed. + + Returns + ------- + bool + True if the hat diagonals layer can be reconstructed, False otherwise. + + Raises + ------ + ValueError + If the hat diagonals layer cannot be reconstructed and raise_error is True. + + """ + if "_hat_diagonals" in adata.layers.keys(): + return True + + if shared_state is None: + if raise_error: + raise ValueError( + "To set the _hat_diagonals layer, there" "should be a shared state." + ) + else: + return False + + has_design_matrix = "design_matrix" in adata.obsm.keys() + has_lfc_param = "LFC" in adata.varm.keys() + has_size_factors = "size_factors" in adata.obsm.keys() + has_non_zero = "non_zero" in adata.varm.keys() + has_dispersion = "dispersions" in adata.varm.keys() + has_global_hat_matrix_inv = "global_hat_matrix_inv" in shared_state.keys() + + has_all = ( + has_design_matrix + and has_lfc_param + and has_size_factors + and has_non_zero + and has_global_hat_matrix_inv + and has_dispersion + ) + if not has_all: + if raise_error: + raise ValueError( + "Adata must contain the design matrix obsm" + ", the LFC varm, the dispersions varm, " + "the size_factors obsm, the non_zero varm " + "and the global_hat_matrix_inv " + "in the shared state to compute the hat diagonals layer." + " Here are the keys present in the adata: " + f"obsm : {adata.obsm.keys()} and varm : {adata.varm.keys()}, and the " + f"shared state keys: {shared_state.keys()}" + ) + return False + return True + + +def set_hat_diagonals_layer( + adata: ad.AnnData, + shared_state: dict | None, + n_jobs: int = 1, + joblib_verbosity: int = 0, + joblib_backend: str = "loky", + batch_size: int = 100, + min_mu: float = 0.5, +): + """ + Compute the hat diagonals layer from the adata and the shared state. + + Parameters + ---------- + adata : ad.AnnData + The AnnData object. + + shared_state : Optional[dict] + The shared state dictionary. + This dictionary must contain the global hat matrix inverse. + + n_jobs : int + The number of jobs to use for parallel processing. + + joblib_verbosity : int + The verbosity level of joblib. + + joblib_backend : str + The joblib backend to use. + + batch_size : int + The batch size for parallel processing. + + min_mu : float + Lower bound on estimated means, to ensure numerical stability. 
+ + Returns + ------- + np.ndarray + The hat diagonals layer, of shape (n_obs, n_params). + + """ + can_set_hat_diagonals_layer(adata, shared_state, raise_error=True) + if "_hat_diagonals" in adata.layers.keys(): + return + + assert shared_state is not None, ( + "To construct the _hat_diagonals layer, " "one must have a shared state." + ) + + gene_names = adata.var_names[adata.varm["non_zero"]] + beta = adata.varm["LFC"].loc[gene_names].to_numpy() + design_matrix = adata.obsm["design_matrix"].values + size_factors = adata.obsm["size_factors"] + + dispersions = adata[:, gene_names].varm["dispersions"] + + # ---- Step 1: Compute the mu and the diagonal of the hat matrix ---- # + + with parallel_backend(joblib_backend): + res = Parallel(n_jobs=n_jobs, verbose=joblib_verbosity)( + delayed(make_hat_diag_batch)( + beta[i : i + batch_size], + shared_state["global_hat_matrix_inv"][i : i + batch_size], + design_matrix, + size_factors, + dispersions[i : i + batch_size], + min_mu, + ) + for i in range(0, len(beta), batch_size) + ) + + H = np.concatenate(res) + + H_layer = np.full(adata.shape, np.NaN) + + H_layer[:, adata.var_names.get_indexer(gene_names)] = H.T + + adata.layers["_hat_diagonals"] = H_layer + + +def make_hat_diag_batch( + beta: np.ndarray, + global_hat_matrix_inv: np.ndarray, + design_matrix: np.ndarray, + size_factors: np.ndarray, + dispersions: np.ndarray, + min_mu: float = 0.5, +) -> np.ndarray: + """ + Compute the H matrix for a batch of LFC estimates. + + Parameters + ---------- + beta : np.ndarray + Current LFC estimate, of shape (batch_size, n_params). + global_hat_matrix_inv : np.ndarray + The inverse of the global hat matrix, of shape (batch_size, n_params, n_params). + design_matrix : np.ndarray + The design matrix, of shape (n_obs, n_params). + size_factors : np.ndarray + The size factors, of shape (n_obs). + dispersions : np.ndarray + The dispersions, of shape (batch_size). + min_mu : float + Lower bound on estimated means, to ensure numerical stability. + (default: ``0.5``). + + Returns + ------- + np.ndarray + The H matrix, of shape (batch_size, n_obs). + + """ + mu = size_factors[:, None] * np.exp(design_matrix @ beta.T) + mu_clipped = np.maximum( + mu, + min_mu, + ) + + # W of shape (n_obs, batch_size) + W = mu_clipped / (1.0 + mu_clipped * dispersions[None, :]) + + # W_sq Of shape (batch_size, n_obs) + W_sq = np.sqrt(W).T + + # Inside the diagonal operator is of shape (batch_size, n_obs, n_obs) + # The diagonal operator takes the diagonal per gene in the batch + # H is therefore of shape (batch_size, n_obs) + H = np.diagonal( + design_matrix @ global_hat_matrix_inv @ design_matrix.T, + axis1=1, + axis2=2, + ) + + H = W_sq * H * W_sq + + return H diff --git a/fedpydeseq2/core/utils/layers/build_layers/mu_hat.py b/fedpydeseq2/core/utils/layers/build_layers/mu_hat.py new file mode 100644 index 0000000..5e23fb0 --- /dev/null +++ b/fedpydeseq2/core/utils/layers/build_layers/mu_hat.py @@ -0,0 +1,103 @@ +"""Module to build the mu_hat layer.""" + +import anndata as ad + +from fedpydeseq2.core.utils.layers.build_layers.fit_lin_mu_hat import ( + can_get_fit_lin_mu_hat, +) +from fedpydeseq2.core.utils.layers.build_layers.fit_lin_mu_hat import set_fit_lin_mu_hat +from fedpydeseq2.core.utils.layers.build_layers.mu_layer import can_set_mu_layer +from fedpydeseq2.core.utils.layers.build_layers.mu_layer import set_mu_layer + + +def can_get_mu_hat(local_adata: ad.AnnData, raise_error: bool = False) -> bool: + """Check if the mu_hat layer can be reconstructed. 
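Standalone sketch, not part of the patch, of the leverage formula that make_hat_diag_batch evaluates batch-wise: H_ii = sqrt(w_i) * [X (X^T W X)^{-1} X^T]_ii * sqrt(w_i). It is written here for one gene and one center with a toy design matrix; the pipeline itself uses the ridge-regularized global hat matrix inverse taken from the shared state.

import numpy as np

n_obs = 6
X = np.column_stack([np.ones(n_obs), np.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])])
size_factors = np.ones(n_obs)
beta = np.array([1.0, 0.5])                  # LFCs in natural log scale
dispersion = 0.1
min_mu = 0.5

mu = np.maximum(size_factors * np.exp(X @ beta), min_mu)
W = mu / (1.0 + mu * dispersion)
hat_inv = np.linalg.inv((X.T * W) @ X)       # plays the role of global_hat_matrix_inv
H = np.sqrt(W) * np.diag(X @ hat_inv @ X.T) * np.sqrt(W)
print(H.shape)                               # (n_obs,), one leverage value per sample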
+ + Parameters + ---------- + local_adata : ad.AnnData + The local AnnData object. + + raise_error : bool, optional + If True, raise an error if the mu_hat layer cannot be reconstructed. + + Returns + ------- + bool + True if the mu_hat layer can be reconstructed, False otherwise. + + Raises + ------ + ValueError + If the mu_hat layer cannot be reconstructed and raise_error is True. + + """ + if "_mu_hat" in local_adata.layers.keys(): + return True + has_num_replicates = "num_replicates" in local_adata.uns + has_n_params = "n_params" in local_adata.uns + if not has_num_replicates or not has_n_params: + if raise_error: + raise ValueError( + "Local adata must contain num_replicates in uns field " + "and n_params in uns field to compute mu_hat." + " Here are the keys present in the local adata: " + f"uns : {local_adata.uns.keys()}" + ) + return False + # If the number of replicates is not equal to the number of parameters, + # we need to reconstruct mu_hat from the adata. + if len(local_adata.uns["num_replicates"]) != local_adata.uns["n_params"]: + try: + mu_hat_LFC_ok = can_set_mu_layer( + local_adata=local_adata, + lfc_param_name="_mu_hat_LFC", + mu_param_name="_irls_mu_hat", + raise_error=raise_error, + ) + except ValueError as mu_hat_LFC_error: + raise ValueError( + "Error while checking if mu_hat_LFC can " + f"be reconstructed: {mu_hat_LFC_error}" + ) from mu_hat_LFC_error + return mu_hat_LFC_ok + else: + try: + fit_lin_mu_hat_ok = can_get_fit_lin_mu_hat( + local_adata=local_adata, + raise_error=raise_error, + ) + except ValueError as fit_lin_mu_hat_error: + raise ValueError( + "Error while checking if fit_lin_mu_hat can be " + f"reconstructed: {fit_lin_mu_hat_error}" + ) from fit_lin_mu_hat_error + return fit_lin_mu_hat_ok + + +def set_mu_hat_layer(local_adata: ad.AnnData): + """ + Reconstruct the mu_hat layer. + + Parameters + ---------- + local_adata: ad.AnnData + The local AnnData object. + + """ + can_get_mu_hat(local_adata, raise_error=True) + if "_mu_hat" in local_adata.layers.keys(): + return + + if len(local_adata.uns["num_replicates"]) != local_adata.uns["n_params"]: + set_mu_layer( + local_adata=local_adata, + lfc_param_name="_mu_hat_LFC", + mu_param_name="_irls_mu_hat", + ) + local_adata.layers["_mu_hat"] = local_adata.layers["_irls_mu_hat"].copy() + return + set_fit_lin_mu_hat( + local_adata=local_adata, + ) + local_adata.layers["_mu_hat"] = local_adata.layers["_fit_lin_mu_hat"].copy() diff --git a/fedpydeseq2/core/utils/layers/build_layers/mu_layer.py b/fedpydeseq2/core/utils/layers/build_layers/mu_layer.py new file mode 100644 index 0000000..d71dce1 --- /dev/null +++ b/fedpydeseq2/core/utils/layers/build_layers/mu_layer.py @@ -0,0 +1,157 @@ +"""Module to construct mu layer from LFC estimates.""" + +import anndata as ad +import numpy as np +from joblib import Parallel +from joblib import delayed +from joblib import parallel_backend + + +def can_set_mu_layer( + local_adata: ad.AnnData, + lfc_param_name: str, + mu_param_name: str, + raise_error: bool = False, +) -> bool: + """Check if the mu layer can be reconstructed. + + Parameters + ---------- + local_adata : ad.AnnData + The local AnnData object. + + lfc_param_name : str + The name of the log fold changes parameter in the adata. + + mu_param_name : str + The name of the mu parameter in the adata. + + raise_error : bool, optional + If True, raise an error if the mu layer cannot be reconstructed. + + Returns + ------- + bool + True if the mu layer can be reconstructed, False otherwise. 
+ + Raises + ------ + ValueError + If the mu layer cannot be reconstructed and raise_error is True. + + """ + if mu_param_name in local_adata.layers.keys(): + return True + + has_design_matrix = "design_matrix" in local_adata.obsm.keys() + has_lfc_param = lfc_param_name in local_adata.varm.keys() + has_size_factors = "size_factors" in local_adata.obsm.keys() + has_non_zero = "non_zero" in local_adata.varm.keys() + + has_all = has_design_matrix and has_lfc_param and has_size_factors and has_non_zero + if not has_all: + if raise_error: + raise ValueError( + "Local adata must contain the design matrix obsm" + f", the {lfc_param_name} varm to compute the mu layer, " + f"the size_factors obsm and the non_zero varm. " + " Here are the keys present in the local adata: " + f"obsm : {local_adata.obsm.keys()} and varm : {local_adata.varm.keys()}" + ) + return False + return True + + +def set_mu_layer( + local_adata: ad.AnnData, + lfc_param_name: str, + mu_param_name: str, + n_jobs: int = 1, + joblib_verbosity: int = 0, + joblib_backend: str = "loky", + batch_size: int = 100, +): + """Reconstruct a mu layer from the adata and a given LFC field. + + Parameters + ---------- + local_adata : ad.AnnData + The local AnnData object. + + lfc_param_name : str + The name of the log fold changes parameter in the adata. + + mu_param_name : str + The name of the mu parameter in the adata. + + n_jobs : int + Number of jobs to run in parallel. + + joblib_verbosity : int + Verbosity level of joblib. + + joblib_backend : str + Joblib backend to use. + + batch_size : int + Batch size for parallelization. + + """ + can_set_mu_layer( + local_adata, lfc_param_name, mu_param_name=mu_param_name, raise_error=True + ) + if mu_param_name in local_adata.layers.keys(): + return + gene_names = local_adata.var_names[local_adata.varm["non_zero"]] + beta = local_adata.varm[lfc_param_name].loc[gene_names].to_numpy() + design_matrix = local_adata.obsm["design_matrix"].values + size_factors = local_adata.obsm["size_factors"] + + with parallel_backend(joblib_backend): + res = Parallel(n_jobs=n_jobs, verbose=joblib_verbosity)( + delayed(make_mu_batch)( + beta[i : i + batch_size], + design_matrix, + size_factors, + ) + for i in range(0, len(beta), batch_size) + ) + + if len(res) == 0: + mu = np.zeros((local_adata.shape[0], 0)) + else: + mu = np.concatenate(list(res), axis=1) + + mu_layer = np.full(local_adata.shape, np.NaN) + + mu_layer[:, local_adata.var_names.get_indexer(gene_names)] = mu + + local_adata.layers[mu_param_name] = mu_layer + + +def make_mu_batch( + beta: np.ndarray, + design_matrix: np.ndarray, + size_factors: np.ndarray, +) -> np.ndarray: + """ + Compute the mu matrix for a batch of LFC estimates. + + Parameters + ---------- + beta : np.ndarray + Current LFC estimate, of shape (batch_size, n_params). + design_matrix : np.ndarray + The design matrix, of shape (n_obs, n_params). + size_factors : np.ndarray + The size factors, of shape (n_obs). + + Returns + ------- + mu : np.ndarray + The mu matrix, of shape (n_obs, batch_size). 
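Minimal sketch, outside the patch, of the mean model reconstructed by make_mu_batch: mu[i, g] = s_i * exp(x_i^T beta_g) for sample i and gene g. The design matrix, size factors and betas below are toy values.

import numpy as np

n_obs = 5
design_matrix = np.column_stack([np.ones(n_obs), np.array([0.0, 0.0, 1.0, 1.0, 1.0])])
size_factors = np.array([0.8, 1.0, 1.2, 0.9, 1.1])
beta = np.array([[1.0, 0.5], [2.0, -0.3], [0.5, 0.0]])   # (batch_size, n_params)

mu = size_factors[:, None] * np.exp(design_matrix @ beta.T)
print(mu.shape)                              # (n_obs, batch_size)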
+ + """ + mu = size_factors[:, None] * np.exp(design_matrix @ beta.T) + + return mu diff --git a/fedpydeseq2/core/utils/layers/build_layers/normed_counts.py b/fedpydeseq2/core/utils/layers/build_layers/normed_counts.py new file mode 100644 index 0000000..a69951d --- /dev/null +++ b/fedpydeseq2/core/utils/layers/build_layers/normed_counts.py @@ -0,0 +1,56 @@ +"""Module to construct the normed_counts layer.""" + +import anndata as ad + + +def can_get_normed_counts(adata: ad.AnnData, raise_error: bool = False) -> bool: + """Check if the normed_counts layer can be reconstructed. + + Parameters + ---------- + adata : ad.AnnData + The local AnnData object. + + raise_error : bool, optional + If True, raise an error if the normed_counts layer cannot be reconstructed. + + Returns + ------- + bool + True if the normed_counts layer can be reconstructed, False otherwise. + + Raises + ------ + ValueError + If the normed_counts layer cannot be reconstructed and raise_error is True. + + """ + if "normed_counts" in adata.layers.keys(): + return True + has_X = adata.X is not None + has_size_factors = "size_factors" in adata.obsm.keys() + if not has_X or not has_size_factors: + if raise_error: + raise ValueError( + "Local adata must contain the X field " + "and the size_factors obsm to compute the normed_counts layer." + " Here are the keys present in the adata: " + f" obsm : {adata.obsm.keys()}" + ) + return False + return True + + +def set_normed_counts(adata: ad.AnnData): + """Reconstruct the normed_counts layer. + + Parameters + ---------- + adata : ad.AnnData + The local AnnData object. + + """ + can_get_normed_counts(adata, raise_error=True) + if "normed_counts" in adata.layers.keys(): + return + adata.layers["normed_counts"] = adata.X / adata.obsm["size_factors"][:, None] diff --git a/fedpydeseq2/core/utils/layers/build_layers/sqerror.py b/fedpydeseq2/core/utils/layers/build_layers/sqerror.py new file mode 100644 index 0000000..74dc651 --- /dev/null +++ b/fedpydeseq2/core/utils/layers/build_layers/sqerror.py @@ -0,0 +1,91 @@ +"""Module to construct the sqerror layer.""" + +import anndata as ad +import numpy as np +import pandas as pd + +from fedpydeseq2.core.utils.layers.build_layers.normed_counts import ( + can_get_normed_counts, +) +from fedpydeseq2.core.utils.layers.build_layers.normed_counts import set_normed_counts + + +def can_get_sqerror_layer(adata: ad.AnnData, raise_error: bool = False) -> bool: + """Check if the squared error layer can be reconstructed. + + Parameters + ---------- + adata : ad.AnnData + The local AnnData object. + + raise_error : bool, optional + If True, raise an error if the squared error layer cannot be reconstructed. + + Returns + ------- + bool + True if the squared error layer can be reconstructed, False otherwise. + + Raises + ------ + ValueError + If the squared error layer cannot be reconstructed and raise_error is True. 
+ + """ + if "sqerror" in adata.layers.keys(): + return True + try: + has_normed_counts = can_get_normed_counts(adata, raise_error=raise_error) + except ValueError as normed_counts_error: + raise ValueError( + f"Error while checking if normed_counts can be" + f" reconstructed: {normed_counts_error}" + ) from normed_counts_error + + has_cell_means = "cell_means" in adata.varm.keys() + has_cell_obs = "cells" in adata.obs.keys() + if not has_normed_counts or not has_cell_means or not has_cell_obs: + if raise_error: + raise ValueError( + "Local adata must contain the normed_counts layer, the cells obs, " + "and the cell_means varm to compute the squared error layer." + " Here are the keys present in the adata: " + f"obs : {adata.obs.keys()}, varm : {adata.varm.keys()}" + ) + return False + return True + + +def set_sqerror_layer(local_adata: ad.AnnData): + """Compute the squared error between the normalized counts and the trimmed mean. + + Parameters + ---------- + local_adata : ad.AnnData + Local AnnData. It is expected to have the following fields: + - layers["normed_counts"]: the normalized counts. + - varm["cell_means"]: the trimmed mean. + - obs["cells"]: the cells. + + """ + can_get_sqerror_layer(local_adata, raise_error=True) + if "sqerror" in local_adata.layers.keys(): + return + cell_means = local_adata.varm["cell_means"] + set_normed_counts(local_adata) + if isinstance(cell_means, pd.DataFrame): + cells = local_adata.obs["cells"] + # restrict to the cells that are in the cell means columns + cells = cells[cells.isin(cell_means.columns)] + qmat = cell_means[cells].T + qmat.index = cells.index + + # initialize wiht nans + layer = np.full_like(local_adata.layers["normed_counts"], np.nan) + indices = local_adata.obs_names.get_indexer(qmat.index) + layer[indices, :] = ( + local_adata[qmat.index, :].layers["normed_counts"] - qmat + ) ** 2 + else: + layer = (local_adata.layers["normed_counts"] - cell_means[None, :]) ** 2 + local_adata.layers["sqerror"] = layer diff --git a/fedpydeseq2/core/utils/layers/build_layers/y_hat.py b/fedpydeseq2/core/utils/layers/build_layers/y_hat.py new file mode 100644 index 0000000..747102c --- /dev/null +++ b/fedpydeseq2/core/utils/layers/build_layers/y_hat.py @@ -0,0 +1,60 @@ +"""Module containing the necessary functions to reconstruct the y_hat layer.""" + +import anndata as ad + + +def can_get_y_hat(local_adata: ad.AnnData, raise_error: bool = False) -> bool: + """Check if the y_hat layer can be reconstructed. + + Parameters + ---------- + local_adata : ad.AnnData + The local AnnData object. + + raise_error : bool, optional + If True, raise an error if the y_hat layer cannot be reconstructed. + + Returns + ------- + bool + True if the y_hat layer can be reconstructed, False otherwise. + + Raises + ------ + ValueError + If the y_hat layer cannot be reconstructed and raise_error is True. + + """ + if "_y_hat" in local_adata.layers.keys(): + return True + has_design_matrix = "design_matrix" in local_adata.obsm.keys() + has_beta_rough_dispersions = "_beta_rough_dispersions" in local_adata.varm.keys() + if not has_design_matrix or not has_beta_rough_dispersions: + if raise_error: + raise ValueError( + "Local adata must contain the design matrix obsm " + "and the _beta_rough_dispersions varm to compute the y_hat layer." + " Here are the keys present in the local adata: " + f"obsm : {local_adata.obsm.keys()} and varm : {local_adata.varm.keys()}" + ) + return False + return True + + +def set_y_hat(local_adata: ad.AnnData): + """Reconstruct the y_hat layer. 
+ + Parameters + ---------- + local_adata : ad.AnnData + The local AnnData object. + + """ + can_get_y_hat(local_adata, raise_error=True) + if "_y_hat" in local_adata.layers.keys(): + return + y_hat = ( + local_adata.obsm["design_matrix"].to_numpy() + @ local_adata.varm["_beta_rough_dispersions"].T + ) + local_adata.layers["_y_hat"] = y_hat diff --git a/fedpydeseq2/core/utils/layers/build_refit_adata.py b/fedpydeseq2/core/utils/layers/build_refit_adata.py new file mode 100644 index 0000000..ffdc4b5 --- /dev/null +++ b/fedpydeseq2/core/utils/layers/build_refit_adata.py @@ -0,0 +1,107 @@ +from typing import Any + +import numpy as np +import pandas as pd + + +def set_basic_refit_adata(self: Any): + """Set the basic refit adata from the local adata. + + This function checks that the local adata is loaded and the replaced + genes are computed and stored in the varm field. It then sets the refit + adata from the local adata. + + Parameters + ---------- + self : Any + The object containing the local adata and the refit adata. + + """ + assert ( + self.local_adata is not None + ), "Local adata must be loaded before setting the refit adata." + assert ( + "replaced" in self.local_adata.varm.keys() + ), "Replaced genes must be computed before setting the refit adata." + + genes_to_replace = pd.Series( + self.local_adata.varm["replaced"], index=self.local_adata.var_names + ) + if self.refit_adata is None: + self.refit_adata = self.local_adata[:, genes_to_replace].copy() + # Clear the varm field of the refit adata + self.refit_adata.varm = None + elif "refitted" not in self.local_adata.varm.keys(): + self.refit_adata.X = self.local_adata[:, genes_to_replace].X.copy() + self.refit_adata.obsm = self.local_adata.obsm + else: + genes_to_refit = pd.Series( + self.local_adata.varm["refitted"], index=self.local_adata.var_names + ) + self.refit_adata.X = self.local_adata[:, genes_to_refit].X.copy() + self.refit_adata.obsm = self.local_adata.obsm + + +def set_imputed_counts_refit_adata(self: Any): + """Set the imputed counts in the refit adata. + + This function checks that the refit adata, the local adata, the replaced + genes, the trimmed mean normed counts, the size factors, the cooks G cutoff, + and the replaceable genes are computed and stored in the appropriate fields. + It then sets the imputed counts in the refit adata. + + Note that this function must be run on an object which already contains + a refit_adata, whose counts, obsm and uns have been set with the + `set_basic_refit_adata` function. + + Parameters + ---------- + self : Any + The object containing the refit adata, the local adata, the replaced + genes, the trimmed mean normed counts, the size factors, the cooks G + cutoff, and the replaceable genes. + + """ + assert ( + self.refit_adata is not None + ), "Refit adata must be loaded before setting the imputed counts." + assert ( + self.local_adata is not None + ), "Local adata must be loaded before setting the imputed counts." + assert ( + "replaced" in self.local_adata.varm.keys() + ), "Replaced genes must be computed before setting the imputed counts." + assert ( + "_trimmed_mean_normed_counts" in self.refit_adata.varm.keys() + ), "Trimmed mean normed counts must be computed before setting the imputed counts." + assert ( + "size_factors" in self.refit_adata.obsm.keys() + ), "Size factors must be computed before setting the imputed counts." + assert ( + "_where_cooks_g_cutoff" in self.local_adata.uns.keys() + ), "Cooks G cutoff must be computed before setting the imputed counts." 
+ assert ( + "replaceable" in self.refit_adata.obsm.keys() + ), "Replaceable genes must be computed before setting the imputed counts." + + trimmed_mean_normed_counts = self.refit_adata.varm["_trimmed_mean_normed_counts"] + + replacement_counts = pd.DataFrame( + self.refit_adata.obsm["size_factors"][:, None] * trimmed_mean_normed_counts, + columns=self.refit_adata.var_names, + index=self.refit_adata.obs_names, + ).astype(int) + + idx = np.zeros(self.local_adata.shape, dtype=bool) + idx[self.local_adata.uns["_where_cooks_g_cutoff"]] = True + + # Restrict to the genes to replace + if "refitted" not in self.local_adata.varm.keys(): + idx = idx[:, self.local_adata.varm["replaced"]] + else: + idx = idx[:, self.local_adata.varm["refitted"]] + + # Replace the counts + self.refit_adata.X[ + self.refit_adata.obsm["replaceable"][:, None] & idx + ] = replacement_counts.values[self.refit_adata.obsm["replaceable"][:, None] & idx] diff --git a/fedpydeseq2/core/utils/layers/cooks_layer.py b/fedpydeseq2/core/utils/layers/cooks_layer.py new file mode 100644 index 0000000..c299b19 --- /dev/null +++ b/fedpydeseq2/core/utils/layers/cooks_layer.py @@ -0,0 +1,328 @@ +from collections.abc import Callable +from functools import wraps +from typing import Any +from typing import cast + +import anndata as ad +import numpy as np +import pandas as pd +from joblib import Parallel +from joblib import delayed +from joblib import parallel_backend + +from fedpydeseq2.core.utils.compute_lfc_utils import get_lfc_utils_from_gene_mask_adata +from fedpydeseq2.core.utils.layers.joblib_utils import get_joblib_parameters + + +def prepare_cooks_local(method: Callable): + """Decorate the local method just preceding a local method needing cooks. + + This method is only applied if the Cooks layer is not present or must not be + saved between steps. + + This step is used to compute the local hat matrix and the mean normed counts. + + Before the method is called, the varEst must be accessed from the shared state, + or from the local adata if it is not present in the shared state. + + The local hat matrix and the mean normed counts are computed, and the following + keys are added to the shared state: + - local_hat_matrix + - mean_normed_counts + - n_samples + - varEst + + Parameters + ---------- + method : Callable + The remote_data method to decorate. + + Returns + ------- + Callable: + The decorated method. 
+ + """ + + @wraps(method) + def method_inner( + self, + data_from_opener: ad.AnnData, + shared_state: Any = None, + **method_parameters, + ): + # ---- Step 0: If can skip, we skip ---- # + if can_skip_local_cooks_preparation(self): + shared_state = method( + self, data_from_opener, shared_state, **method_parameters + ) + shared_state["_skip_cooks"] = True + return shared_state + + # ---- Step 1: Access varEst ---- # + + if "varEst" in self.local_adata.varm.keys(): + varEst = self.local_adata.varm["varEst"] + else: + assert "varEst" in shared_state + varEst = shared_state["varEst"] + self.local_adata.varm["varEst"] = varEst + + # ---- Step 2: Run the method ---- # + shared_state = method(self, data_from_opener, shared_state, **method_parameters) + + # ---- Step 3: Compute the local hat matrix ---- # + + n_jobs, joblib_verbosity, joblib_backend, batch_size = get_joblib_parameters( + self + ) + # Compute hat matrix + ( + gene_names, + design_matrix, + size_factors, + counts, + dispersions, + beta, + ) = get_lfc_utils_from_gene_mask_adata( + self.local_adata, + None, + "dispersions", + lfc_param_name="LFC", + ) + + with parallel_backend(joblib_backend): + res = Parallel(n_jobs=n_jobs, verbose=joblib_verbosity)( + delayed(make_hat_matrix_summands_batch)( + design_matrix, + size_factors, + beta[i : i + batch_size], + dispersions[i : i + batch_size], + self.min_mu, + ) + for i in range(0, len(beta), batch_size) + ) + + if len(res) == 0: + H = np.zeros((0, beta.shape[1], beta.shape[1])) + else: + H = np.concatenate(res) + + shared_state["local_hat_matrix"] = H + + # ---- Step 4: Compute the mean normed counts ---- # + + mean_normed_counts = self.local_adata.layers["normed_counts"].mean(axis=0) + + shared_state["mean_normed_counts"] = mean_normed_counts + shared_state["n_samples"] = self.local_adata.n_obs + shared_state["varEst"] = varEst + shared_state["_skip_cooks"] = False + + return shared_state + + return method_inner + + +def prepare_cooks_agg(method: Callable): + """Decorate the aggregation step to compute the Cook's distance. + + This decorator is supposed to be placed on the aggregation step just before + a local step which needs the "cooks" layer. The decorator will check if the + shared state contains the necessary keys for the Cook's distance computation. + If this is not the case, then the Cook's distance must have been saved in the + layers_to_save. + It will compute the Cook's dispersion, the hat matrix inverse, and then call + the method. + + It will add the following keys to the shared state: + - cooks_dispersions + - global_hat_matrix_inv + + Parameters + ---------- + method : Callable + The aggregation method to decorate. + It must have the following signature: + method(self, shared_states: Optional[list], **method_parameters). + + Returns + ------- + Callable: + The decorated method. 
+ + """ + + @wraps(method) + def method_inner( + self, + shared_states: list | None, + **method_parameters, + ): + # Check that the shared state contains the necessary keys + # for the Cook's distance computation + # If this is not the case, then the cooks distance must have + # been saved in the layers_to_save + + try: + assert isinstance(shared_states, list) + assert "n_samples" in shared_states[0].keys() + assert "varEst" in shared_states[0].keys() + assert "mean_normed_counts" in shared_states[0].keys() + assert "local_hat_matrix" in shared_states[0].keys() + except AssertionError as assertion_error: + only_from_disk = ( + not hasattr(self, "save_layers_to_disk") or self.save_layers_to_disk + ) + if only_from_disk: + return method(self, shared_states, **method_parameters) + elif isinstance(shared_states, list) and shared_states[0]["_skip_cooks"]: + return method(self, shared_states, **method_parameters) + raise ValueError( + "The shared state does not contain the necessary keys for" + "the Cook's distance computation." + ) from assertion_error + + assert isinstance(shared_states, list) + + # ---- Step 1: Compute Cooks dispersion ---- # + + n_sample_tot = sum( + [shared_state["n_samples"] for shared_state in shared_states] + ) + varEst = shared_states[0]["varEst"] + mean_normed_counts = ( + np.array( + [ + (shared_state["mean_normed_counts"] * shared_state["n_samples"]) + for shared_state in shared_states + ] + ).sum(axis=0) + / n_sample_tot + ) + mask_zero = mean_normed_counts == 0 + mask_varEst_zero = varEst == 0 + alpha = varEst - mean_normed_counts + alpha[~mask_zero] = alpha[~mask_zero] / mean_normed_counts[~mask_zero] ** 2 + alpha[mask_varEst_zero & mask_zero] = np.nan + alpha[mask_varEst_zero & (~mask_zero)] = ( + np.inf * alpha[mask_varEst_zero & (~mask_zero)] + ) + + # cannot use the typical min_disp = 1e-8 here or else all counts in the same + # group as the outlier count will get an extreme Cook's distance + minDisp = 0.04 + alpha = cast(pd.Series, np.maximum(alpha, minDisp)) + + # --- Step 2: Compute the hat matrix inverse --- # + + global_hat_matrix = sum([state["local_hat_matrix"] for state in shared_states]) + n_jobs, joblib_verbosity, joblib_backend, batch_size = get_joblib_parameters( + self + ) + ridge_factor = np.diag(np.repeat(1e-6, global_hat_matrix.shape[1])) + with parallel_backend(joblib_backend): + res = Parallel(n_jobs=n_jobs, verbose=joblib_verbosity)( + delayed(np.linalg.inv)(hat_matrices + ridge_factor) + for hat_matrices in np.split( + global_hat_matrix, + range( + batch_size, + len(global_hat_matrix), + batch_size, + ), + ) + ) + + global_hat_matrix_inv = np.concatenate(res) + + # ---- Step 3: Run the method ---- # + + shared_state = method(self, shared_states, **method_parameters) + + # ---- Step 4: Save the Cook's dispersion and the hat matrix inverse ---- # + + shared_state["cooks_dispersions"] = alpha + shared_state["global_hat_matrix_inv"] = global_hat_matrix_inv + + return shared_state + + return method_inner + + +def can_skip_local_cooks_preparation(self: Any) -> bool: + """Check if the Cook's distance is in the layers to save. + + This function checks if the Cook's distance is in the layers to save. + + Parameters + ---------- + self : Any + The object. + + Returns + ------- + bool: + Whether the Cook's distance is in the layers to save. 
+ + """ + only_from_disk = ( + not hasattr(self, "save_layers_to_disk") or self.save_layers_to_disk + ) + if only_from_disk and "cooks" in self.local_adata.layers.keys(): + return True + if hasattr(self, "layers_to_save_on_disk"): + layers_to_save_on_disk = self.layers_to_save_on_disk + if ( + layers_to_save_on_disk is not None + and "local_adata" in layers_to_save_on_disk + and layers_to_save_on_disk["local_adata"] is not None + and "cooks" in layers_to_save_on_disk["local_adata"] + ): + return True + return False + + +def make_hat_matrix_summands_batch( + design_matrix: np.ndarray, + size_factors: np.ndarray, + beta: np.ndarray, + dispersions: np.ndarray, + min_mu: float, +) -> np.ndarray: + """Make the local hat matrix. + + This is quite similar to the make_irls_summands_batch function, but it does not + require the counts, and returns only the H matrix. + + This is used in the final step of the IRLS algorithm to compute the local hat + matrix. + + Parameters + ---------- + design_matrix : np.ndarray + The design matrix, of shape (n_obs, n_params). + size_factors : np.ndarray + The size factors, of shape (n_obs). + beta : np.ndarray + The log fold change matrix, of shape (batch_size, n_params). + dispersions : np.ndarray + The dispersions, of shape (batch_size). + min_mu : float + Lower bound on estimated means, to ensure numerical stability. + + + Returns + ------- + H : np.ndarray + The H matrix, of shape (batch_size, n_params, n_params). + """ + mu = size_factors[:, None] * np.exp(design_matrix @ beta.T) + + mu = np.maximum(mu, min_mu) + + W = mu / (1.0 + mu * dispersions[None, :]) + + H = (design_matrix.T[:, :, None] * W).transpose(2, 0, 1) @ design_matrix[None, :, :] + + return H diff --git a/fedpydeseq2/core/utils/layers/joblib_utils.py b/fedpydeseq2/core/utils/layers/joblib_utils.py new file mode 100644 index 0000000..b752cd9 --- /dev/null +++ b/fedpydeseq2/core/utils/layers/joblib_utils.py @@ -0,0 +1,32 @@ +from typing import Any + + +def get_joblib_parameters(x: Any) -> tuple[int, int, str, int]: + """ + Get the joblib parameters from an object, and return them as a tuple. + + If the object has no joblib parameters, default values are returned. + + Parameters + ---------- + x: Any + Object from which to extract the joblib parameters. + + Returns + ------- + n_jobs: int + Number of jobs to run in parallel. + joblib_verbosity: int + Verbosity level of joblib. + joblib_backend: str + Joblib backend. + batch_size: int + Batch size for the IRLS algorithm. + + """ + n_jobs = x.num_jobs if hasattr(x, "num_jobs") else 1 + + joblib_verbosity = x.joblib_verbosity if hasattr(x, "joblib_verbosity") else 0 + joblib_backend = x.joblib_backend if hasattr(x, "joblib_backend") else "loky" + batch_size = x.irls_batch_size if hasattr(x, "irls_batch_size") else 100 + return n_jobs, joblib_verbosity, joblib_backend, batch_size diff --git a/fedpydeseq2/core/utils/layers/reconstruct_adatas_decorator.py b/fedpydeseq2/core/utils/layers/reconstruct_adatas_decorator.py new file mode 100644 index 0000000..f1084f7 --- /dev/null +++ b/fedpydeseq2/core/utils/layers/reconstruct_adatas_decorator.py @@ -0,0 +1,250 @@ +"""Module containing a decorator to handle simple layers. + +This wrapper is used to load and save simple layers from the adata object. +These simple layers are defined in SIMPLE_LAYERS. 
+ + +""" + +from collections.abc import Callable +from functools import wraps +from typing import Any + +import anndata as ad +import numpy as np + +from fedpydeseq2.core.utils.layers.build_refit_adata import set_basic_refit_adata +from fedpydeseq2.core.utils.layers.build_refit_adata import ( + set_imputed_counts_refit_adata, +) +from fedpydeseq2.core.utils.layers.joblib_utils import get_joblib_parameters +from fedpydeseq2.core.utils.layers.utils import get_available_layers +from fedpydeseq2.core.utils.layers.utils import load_layers +from fedpydeseq2.core.utils.layers.utils import remove_layers + +LayersToLoadSaveType = dict[str, list[str] | None] | None + + +def reconstruct_adatas(method: Callable): + """Decorate a method to load layers and remove them before saving the state. + + This decorator loads the layers from the data_from_opener and the adata + object before calling the method. It then removes the layers from the adata + object after the method is called. + + The object self CAN have the following attributes: + + - save_layers_to_disk: if this argument exists or is True, we save all the layers + on disk, without removing them at the end of each local step. If it is False, + we remove all layers that must be removed at the end of each local step. + This argument is prevalent above all others described below. + + - layers_to_save_on_disk: if this argument exists, contains the layers that + must be saved on disk at EVERY local step. It can be either None (in which + case the default behaviour is to save no layers) or a dictionary with a refit_adata + and local_adata key. The associated values contain either None (no layers) or + a list of layers to save at each step. + + This decorator adds two parameters to each method decorated with it: + - layers_to_load + - layers_to_save_on_disk + + If the layers_to_load is None, the default is to load all available layers. + Else, we only load the layers specified in the layers_to_load argument. + + The layers_to_save_on_disk argument is ADDED to the layers_to_save_on_disk attribute + of self for the duration of the method and then removed. That way, the inner + method can access the names of the layers_to_save_on_disk which will effectively + be saved at the end of the step. + + Parameters + ---------- + method : Callable + The method to decorate. This method is expected to have the following signature: + method(self, data_from_opener: ad.AnnData, shared_state: Any, + **method_parameters). + + Returns + ------- + Callable + The decorated method, which loads the simple layers before calling the method + and removes the simple layers after the method is called. 
+ + """ + + @wraps(method) + def method_inner( + self, + data_from_opener: ad.AnnData, + shared_state: Any = None, + layers_to_load: LayersToLoadSaveType = None, + layers_to_save_on_disk: LayersToLoadSaveType = None, + **method_parameters, + ): + if layers_to_load is None: + layers_to_load = {"local_adata": None, "refit_adata": None} + if hasattr(self, "layers_to_save_on_disk"): + if self.layers_to_save_on_disk is None: + global_layers_to_save_on_disk = None + else: + global_layers_to_save_on_disk = self.layers_to_save_on_disk.copy() + + if global_layers_to_save_on_disk is None: + self.layers_to_save_on_disk = {"local_adata": [], "refit_adata": []} + else: + self.layers_to_save_on_disk = {"local_adata": [], "refit_adata": []} + + if layers_to_save_on_disk is None: + layers_to_save_on_disk = {"local_adata": [], "refit_adata": []} + + # Set the layers_to_save_on_disk attribute to the union of the layers specified + # in the argument and those in the attribute, to be accessed by the method. + assert isinstance(self.layers_to_save_on_disk, dict) + for adata_name in ["local_adata", "refit_adata"]: + if self.layers_to_save_on_disk[adata_name] is None: + self.layers_to_save_on_disk[adata_name] = [] + if layers_to_save_on_disk[adata_name] is None: + layers_to_save_on_disk[adata_name] = [] + self.layers_to_save_on_disk[adata_name] = list( + set( + layers_to_save_on_disk[adata_name] + + self.layers_to_save_on_disk[adata_name] + ) + ) + + # Check that the layers_to_load and layers_to_save are valid + assert set(layers_to_load.keys()) == {"local_adata", "refit_adata"} + assert set(self.layers_to_save_on_disk.keys()) == {"local_adata", "refit_adata"} + + # Load the counts of the adata + if self.local_adata is not None: + if self.local_adata.X is None: + self.local_adata.X = data_from_opener.X + + # Load the available layers + only_from_disk = ( + not hasattr(self, "save_layers_to_disk") or self.save_layers_to_disk + ) + + # Start by loading the local adata + check_and_load_layers( + self, "local_adata", layers_to_load, shared_state, only_from_disk + ) + + # Create the refit adata + reconstruct_refit_adata_without_layers(self) + + # Load the layers of the refit adata + check_and_load_layers( + self, "refit_adata", layers_to_load, shared_state, only_from_disk + ) + + # Apply the method + shared_state = method(self, data_from_opener, shared_state, **method_parameters) + + # Remove all layers which must not be saved on disk + for adata_name in ["local_adata", "refit_adata"]: + adata = getattr(self, adata_name) + if adata is None: + continue + if only_from_disk: + layers_to_save_on_disk_adata: list | None = list(adata.layers.keys()) + else: + layers_to_save_on_disk_adata = self.layers_to_save_on_disk[adata_name] + assert layers_to_save_on_disk_adata is not None + for layer in layers_to_save_on_disk_adata: + if layer not in adata.layers.keys(): + print("Warning: layer not in adata: ", layer) + assert layers_to_save_on_disk_adata is not None + remove_layers( + adata=adata, + layers_to_save_on_disk=layers_to_save_on_disk_adata, + refit=adata_name == "refit_adata", + ) + + # Reset the layers_to_save_on_disk attribute + try: + self.layers_to_save_on_disk = global_layers_to_save_on_disk + except NameError: + del self.layers_to_save_on_disk + + return shared_state + + return method_inner + + +def reconstruct_refit_adata_without_layers(self: Any): + """Reconstruct the refit adata without the layers. + + This function reconstructs the refit adata without the layers. 
+ It is used to avoid the counts and the obsm being loaded uselessly in the + refit_adata. + + Parameters + ---------- + self : Any + The object containing the adata. + + """ + if self.refit_adata is None: + return + if self.local_adata is not None and "replaced" in self.local_adata.varm.keys(): + set_basic_refit_adata(self) + if self.local_adata is not None and "refitted" in self.local_adata.varm.keys(): + set_imputed_counts_refit_adata(self) + + +def check_and_load_layers( + self: Any, + adata_name: str, + layers_to_load: dict[str, list[str] | None], + shared_state: dict | None, + only_from_disk: bool, +): + """Check and load layers for a given adata_name. + + This function checks the availability of the layers to load + and loads them, for the adata_name adata. + + Parameters + ---------- + self : Any + The object containing the adata. + adata_name : str + The name of the adata to load the layers into. + layers_to_load : dict[str, Optional[list[str]]] + The layers to load for each adata. It must have adata_name + as a key. + shared_state : Optional[dict] + The shared state. + only_from_disk : bool + Whether to load only the layers from disk. + + """ + adata = getattr(self, adata_name) + layers_to_load_adata = layers_to_load[adata_name] + available_layers_adata = get_available_layers( + adata, + shared_state, + refit=adata_name == "refit_adata", + all_layers_from_disk=only_from_disk, + ) + if layers_to_load_adata is None: + layers_to_load_adata = available_layers_adata + else: + assert np.all( + [layer in available_layers_adata for layer in layers_to_load_adata] + ) + if adata is None: + return + assert layers_to_load_adata is not None + n_jobs, joblib_verbosity, joblib_backend, batch_size = get_joblib_parameters(self) + load_layers( + adata=adata, + shared_state=shared_state, + layers_to_load=layers_to_load_adata, + n_jobs=n_jobs, + joblib_verbosity=joblib_verbosity, + joblib_backend=joblib_backend, + batch_size=batch_size, + ) diff --git a/fedpydeseq2/core/utils/layers/utils.py b/fedpydeseq2/core/utils/layers/utils.py new file mode 100644 index 0000000..d599a27 --- /dev/null +++ b/fedpydeseq2/core/utils/layers/utils.py @@ -0,0 +1,214 @@ +import anndata as ad +import numpy as np + +from fedpydeseq2.core.utils.layers.build_layers import can_get_fit_lin_mu_hat +from fedpydeseq2.core.utils.layers.build_layers import can_get_mu_hat +from fedpydeseq2.core.utils.layers.build_layers import can_get_normed_counts +from fedpydeseq2.core.utils.layers.build_layers import can_get_sqerror_layer +from fedpydeseq2.core.utils.layers.build_layers import can_get_y_hat +from fedpydeseq2.core.utils.layers.build_layers import can_set_cooks_layer +from fedpydeseq2.core.utils.layers.build_layers import can_set_hat_diagonals_layer +from fedpydeseq2.core.utils.layers.build_layers import can_set_mu_layer +from fedpydeseq2.core.utils.layers.build_layers import set_cooks_layer +from fedpydeseq2.core.utils.layers.build_layers import set_fit_lin_mu_hat +from fedpydeseq2.core.utils.layers.build_layers import set_hat_diagonals_layer +from fedpydeseq2.core.utils.layers.build_layers import set_mu_hat_layer +from fedpydeseq2.core.utils.layers.build_layers import set_mu_layer +from fedpydeseq2.core.utils.layers.build_layers import set_normed_counts +from fedpydeseq2.core.utils.layers.build_layers import set_sqerror_layer +from fedpydeseq2.core.utils.layers.build_layers import set_y_hat + +AVAILABLE_LAYERS = [ + "normed_counts", + "_mu_LFC", + "_irls_mu_hat", + "sqerror", + "_y_hat", + "_fit_lin_mu_hat", + "_mu_hat", + 
"_hat_diagonals", + "cooks", +] + + +def get_available_layers( + adata: ad.AnnData | None, + shared_state: dict | None, + refit: bool = False, + all_layers_from_disk: bool = False, +) -> list[str]: + """Get the available layers in the adata. + + Parameters + ---------- + adata : Optional[ad.AnnData] + The local adata. + + shared_state : dict + The shared state containing the Cook's dispersion values. + + refit : bool + Whether to refit the layers. + + all_layers_from_disk : bool + Whether to get all layers from disk. + + Returns + ------- + list[str] + List of available layers. + + """ + if adata is None: + return [] + if all_layers_from_disk: + return list(adata.layers.keys()) + available_layers = [] + if can_get_normed_counts(adata, raise_error=False): + available_layers.append("normed_counts") + if can_get_y_hat(adata, raise_error=False): + available_layers.append("_y_hat") + if can_get_mu_hat(adata, raise_error=False): + available_layers.append("_mu_hat") + if can_get_fit_lin_mu_hat(adata, raise_error=False): + available_layers.append("_fit_lin_mu_hat") + if can_get_sqerror_layer(adata, raise_error=False): + available_layers.append("sqerror") + if not refit and can_set_cooks_layer( + adata, shared_state=shared_state, raise_error=False + ): + available_layers.append("cooks") + if not refit and can_set_hat_diagonals_layer( + adata, shared_state=shared_state, raise_error=False + ): + available_layers.append("_hat_diagonals") + if can_set_mu_layer( + adata, lfc_param_name="LFC", mu_param_name="_mu_LFC", raise_error=False + ): + available_layers.append("_mu_LFC") + if can_set_mu_layer( + adata, + lfc_param_name="_mu_hat_LFC", + mu_param_name="_irls_mu_hat", + raise_error=False, + ): + available_layers.append("_irls_mu_hat") + + return available_layers + + +def load_layers( + adata: ad.AnnData, + shared_state: dict | None, + layers_to_load: list[str], + n_jobs: int = 1, + joblib_verbosity: int = 0, + joblib_backend: str = "loky", + batch_size: int = 100, +): + """Load the simple layers from the data_from_opener and the adata object. + + This function loads the layers in the layers_to_load attribute in the + adata object. + + Parameters + ---------- + adata : ad.AnnData + The AnnData object to load the layers into. + + shared_state : dict or None + The shared state containing the Cook's dispersion values. + + layers_to_load : list[str] + The list of layers to load. + + n_jobs : int + The number of jobs to use for parallel processing. + + joblib_verbosity : int + The verbosity level of joblib. + + joblib_backend : str + The joblib backend to use. + + batch_size : int + The batch size for parallel processing. 
+ + """ + # Assert that all layers are either complex or simple + assert np.all( + layer in AVAILABLE_LAYERS for layer in layers_to_load + ), f"All layers in layers_to_load must be in {AVAILABLE_LAYERS}" + + if "normed_counts" in layers_to_load: + set_normed_counts(adata=adata) + if "_mu_LFC" in layers_to_load: + set_mu_layer( + local_adata=adata, + lfc_param_name="LFC", + mu_param_name="_mu_LFC", + n_jobs=n_jobs, + joblib_verbosity=joblib_verbosity, + joblib_backend=joblib_backend, + batch_size=batch_size, + ) + if "_irls_mu_hat" in layers_to_load: + set_mu_layer( + local_adata=adata, + lfc_param_name="_mu_hat_LFC", + mu_param_name="_irls_mu_hat", + n_jobs=n_jobs, + joblib_verbosity=joblib_verbosity, + joblib_backend=joblib_backend, + batch_size=batch_size, + ) + if "sqerror" in layers_to_load: + set_sqerror_layer(adata) + if "_y_hat" in layers_to_load: + set_y_hat(adata) + if "_fit_lin_mu_hat" in layers_to_load: + set_fit_lin_mu_hat(adata) + if "_mu_hat" in layers_to_load: + set_mu_hat_layer(adata) + if "_hat_diagonals" in layers_to_load: + set_hat_diagonals_layer(adata=adata, shared_state=shared_state) + if "cooks" in layers_to_load: + set_cooks_layer(adata=adata, shared_state=shared_state) + + +def remove_layers( + adata: ad.AnnData, + layers_to_save_on_disk: list[str], + refit: bool = False, +): + """Remove the simple layers from the adata object. + + This function removes the simple layers from the adata object. The layers_to_save + parameter can be used to specify which layers to save in the local state. + If layers_to_save is None, no layers are saved. + + This function also adds all present layers to the _available_layers field in the + adata object. This field is used to keep track of the layers that are present in + the adata object. + + Parameters + ---------- + adata : ad.AnnData + The AnnData object to remove the layers from. + + refit : bool + Whether the adata object is the refit_adata object. + + layers_to_save_on_disk : list[str] + The list of layers to save. If None, no layers are saved. 
+ + """ + adata.X = None + if refit: + adata.obsm = None + + layer_names = list(adata.layers.keys()).copy() + for layer_name in layer_names: + if layer_name in layers_to_save_on_disk: + continue + del adata.layers[layer_name] diff --git a/fedpydeseq2/core/utils/logging/__init__.py b/fedpydeseq2/core/utils/logging/__init__.py new file mode 100644 index 0000000..190db90 --- /dev/null +++ b/fedpydeseq2/core/utils/logging/__init__.py @@ -0,0 +1,3 @@ +from fedpydeseq2.core.utils.logging.logging_decorators import log_remote +from fedpydeseq2.core.utils.logging.logging_decorators import log_remote_data +from fedpydeseq2.core.utils.logging.logging_decorators import log_save_local_state diff --git a/fedpydeseq2/core/utils/logging/default_config.ini b/fedpydeseq2/core/utils/logging/default_config.ini new file mode 100644 index 0000000..42fd37a --- /dev/null +++ b/fedpydeseq2/core/utils/logging/default_config.ini @@ -0,0 +1,21 @@ +[loggers] +keys=root + +[handlers] +keys=consoleHandler + +[formatters] +keys=sampleFormatter + +[logger_root] +level=WARNING +handlers=consoleHandler + +[handler_consoleHandler] +class=StreamHandler +level=WARNING +formatter=sampleFormatter +args=(sys.stdout,) + +[formatter_sampleFormatter] +format=%(asctime)s - %(name)s - %(levelname)s - %(message)s diff --git a/fedpydeseq2/core/utils/logging/logging_decorators.py b/fedpydeseq2/core/utils/logging/logging_decorators.py new file mode 100644 index 0000000..67dba90 --- /dev/null +++ b/fedpydeseq2/core/utils/logging/logging_decorators.py @@ -0,0 +1,215 @@ +""" +Module containing decorators to log the input and outputs of a method. + +All logging is controlled through a logging configuration file. +This configuration file can be either set by the log_config_path attribute of the class, +or by the default_config.ini file in the same directory as this module. +""" +import logging +import logging.config +import os +import pathlib +from collections.abc import Callable +from functools import wraps +from typing import Any + +import anndata as ad + + +def log_save_local_state(method: Callable): + """ + Decorate a method to log the size of the local state saved. + + This function is destined to decorate the save_local_state method of a class. + + It logs the size of the local state saved in the local state path, in MB. + This is logged as an info message. + + Parameters + ---------- + method : Callable + The method to decorate. This method is expected to have the following signature: + method(self, path: pathlib.Path). + + Returns + ------- + Callable + The decorated method, which logs the size of the local state saved. + + """ + + @wraps(method) + def remote_method_inner( + self, + path: pathlib.Path, + ): + logger = get_method_logger(self, method) + + output = method(self, path) + + logger.info( + f"Size of local state saved : " + f"{os.path.getsize(path) / 1024 / 1024}" + " MB" + ) + + return output + + return remote_method_inner + + +def log_remote_data(method: Callable): + """ + Decorate a remote_data to log the input and outputs. + + This decorator logs the shared state keys with the info level, + and the different layers of the local_adata and refit_adata with the debug level. + + This is done before and after the method call. + + Parameters + ---------- + method : Callable + The method to decorate. This method is expected to have the following signature: + method(self, data_from_opener: ad.AnnData, + shared_state: Any = None, **method_parameters). 
+ + Returns + ------- + Callable + The decorated method, which logs the shared state keys with the info level + and the different layers of the local_adata and refit_adata with the debug + level. + """ + + @wraps(method) + def remote_method_inner( + self, + data_from_opener: ad.AnnData, + shared_state: Any = None, + **method_parameters, + ): + logger = get_method_logger(self, method) + logger.info("---- Before running the method ----") + log_shared_state_adatas(self, method, shared_state) + + shared_state = method(self, data_from_opener, shared_state, **method_parameters) + + logger.info("---- After method ----") + log_shared_state_adatas(self, method, shared_state) + return shared_state + + return remote_method_inner + + +def log_remote(method: Callable): + """ + Decorate a remote method to log the input and outputs. + + This decorator logs the shared state keys with the info level. + + Parameters + ---------- + method : Callable + The method to decorate. This method is expected to have the following signature: + method(self, shared_states: Optional[list], **method_parameters). + + Returns + ------- + Callable + The decorated method, which logs the shared state keys with the info level. + + """ + + @wraps(method) + def remote_method_inner( + self, + shared_states: list | None, + **method_parameters, + ): + logger = get_method_logger(self, method) + if shared_states is not None: + shared_state = shared_states[0] + if shared_state is not None: + logger.info( + f"First input shared state keys : {list(shared_state.keys())}" + ) + else: + logger.info("First input shared state is None.") + else: + logger.info("No input shared states.") + + shared_state = method(self, shared_states, **method_parameters) + + if shared_state is not None: + logger.info(f"Output shared state keys : {list(shared_state.keys())}") + else: + logger.info("No output shared state.") + + return shared_state + + return remote_method_inner + + +def log_shared_state_adatas(self: Any, method: Callable, shared_state: dict | None): + """ + Log the information of the local step. + + Precisely, log the shared state keys (info), + and the different layers of the local_adata and refit_adata (debug). + + Parameters + ---------- + self : Any + The class instance + method : Callable + The class method. + shared_state : Optional[dict] + The shared state dictionary, whose keys we log with the info level. + + """ + logger = get_method_logger(self, method) + + if shared_state is not None: + logger.info(f"Shared state keys : {list(shared_state.keys())}") + else: + logger.info("No shared state") + + for adata_name in ["local_adata", "refit_adata"]: + if hasattr(self, adata_name) and getattr(self, adata_name) is not None: + adata = getattr(self, adata_name) + logger.debug(f"{adata_name} layers : {list(adata.layers.keys())}") + if "_available_layers" in self.local_adata.uns: + available_layers = self.local_adata.uns["_available_layers"] + logger.debug(f"{adata_name} available layers : {available_layers}") + logger.debug(f"{adata_name} uns keys : {list(adata.uns.keys())}") + logger.debug(f"{adata_name} varm keys : {list(adata.varm.keys())}") + logger.debug(f"{adata_name} obsm keys : {list(adata.obsm.keys())}") + + +def get_method_logger(self: Any, method: Callable) -> logging.Logger: + """ + Get the method logger from a configuration file. + + If the class instance has a log_config_path attribute, + the logger is configured with the file at this path. 
+ + Parameters + ---------- + self: Any + The class instance + method: Callable + The class method. + + Returns + ------- + logging.Logger + The logger instance. + """ + if hasattr(self, "log_config_path"): + log_config_path = pathlib.Path(self.log_config_path) + else: + log_config_path = pathlib.Path(__file__).parent / "default_config.ini" + logging.config.fileConfig(log_config_path, disable_existing_loggers=False) + logger = logging.getLogger(method.__name__) + return logger diff --git a/fedpydeseq2/core/utils/mle.py b/fedpydeseq2/core/utils/mle.py new file mode 100644 index 0000000..ba1b347 --- /dev/null +++ b/fedpydeseq2/core/utils/mle.py @@ -0,0 +1,305 @@ +import numpy as np +from pydeseq2.utils import dnb_nll +from pydeseq2.utils import nb_nll + +from fedpydeseq2.core.utils.negative_binomial import grid_nb_nll +from fedpydeseq2.core.utils.negative_binomial import vec_nb_nll_grad + + +def vec_loss( + counts: np.ndarray, + design: np.ndarray, + mu: np.ndarray, + alpha: np.ndarray, + cr_reg: bool = True, + prior_reg: bool = False, + alpha_hat: np.ndarray | None = None, + prior_disp_var: float | None = None, +) -> np.ndarray: + """Compute the adjusted negative log likelihood of a batch of genes. + + Includes Cox-Reid regularization and (optionally) prior regularization. + + Parameters + ---------- + counts : ndarray + Raw counts for a set of genes (n_samples x n_genes). + + design : ndarray + Design matrix (n_samples x n_params). + + mu : ndarray + Mean estimation for the NB model (n_samples x n_genes). + + alpha : ndarray + Dispersion estimates (n_genes). + + cr_reg : bool + Whether to include Cox-Reid regularization (default: True). + + prior_reg : bool + Whether to include prior regularization (default: False). + + alpha_hat : ndarray, optional + Reference dispersions (for MAP estimation, n_genes). + + prior_disp_var : float, optional + Prior dispersion variance. + + Returns + ------- + ndarray + Adjusted negative log likelihood (n_genes). + """ + # closure to be minimized + reg = 0 + if cr_reg: + W = mu / (1 + mu * alpha) + reg += ( + 0.5 + * np.linalg.slogdet((design.T[:, :, None] * W).transpose(2, 0, 1) @ design)[ + 1 + ] + ) + if prior_reg: + if prior_disp_var is None: + raise ValueError("Sigma_prior is required for prior regularization") + reg += (np.log(alpha) - np.log(alpha_hat)) ** 2 / (2 * prior_disp_var) + return nb_nll(counts, mu, alpha) + reg + + +def local_grid_summands( + counts: np.ndarray, + design: np.ndarray, + mu: np.ndarray, + alpha_grid: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Compute local summands of the adjusted negative log likelihood on a grid. + + Includes the Cox-Reid regularization. + + Parameters + ---------- + counts : ndarray + Raw counts for a set of genes (n_samples x n_genes). + + design : ndarray + Design matrix (n_samples x n_params). + + mu : ndarray + Mean estimation for the NB model (n_samples x n_genes). + + alpha_grid : ndarray + Dispersion estimates (n_genes x grid_length). + + Returns + ------- + nll : ndarray + Negative log likelihoods of size (n_genes x grid_length). + + cr_matrix : ndarray + Summands for the Cox-Reid adjustment + (n_genes x grid_length x n_params x n_params). 
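Minimal NumPy sketch, not part of the patch, of the Cox-Reid adjustment added to the negative log-likelihood in vec_loss: 0.5 * logdet(X^T W X) with W = mu / (1 + mu * alpha), evaluated here for a small batch of genes with toy values.

import numpy as np

n_obs, n_genes = 8, 3
design = np.column_stack([np.ones(n_obs), np.repeat([0.0, 1.0], 4)])
mu = np.linspace(1.0, 20.0, n_obs * n_genes).reshape(n_obs, n_genes)
alpha = np.array([0.05, 0.1, 0.5])

W = mu / (1.0 + mu * alpha)                                      # (n_obs, n_genes)
xtwx = (design.T[:, :, None] * W).transpose(2, 0, 1) @ design    # (n_genes, n_params, n_params)
cr_term = 0.5 * np.linalg.slogdet(xtwx)[1]                       # (n_genes,)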
+ """ + # W is of size (n_samples x n_genes x grid_length) + W = mu[:, :, None] / (1 + mu[:, :, None] * alpha_grid) + # cr_matrix is of size (n_genes x grid_length x n_params x n_params) + cr_matrix = (design.T[:, :, None, None] * W).transpose(2, 3, 0, 1) @ design[ + None, None, :, : + ] + # cr_matrix is of size (n_genes x grid_length) + nll = grid_nb_nll(counts, mu, alpha_grid) + + return nll, cr_matrix + + +def global_grid_cr_loss( + nll: np.ndarray, + cr_grid: np.ndarray, +) -> np.ndarray: + """Compute the global negative log likelihood on a grid. + + Sums previously computed local negative log likelihoods and Cox-Reid adjustments. + + Parameters + ---------- + nll : ndarray + Negative log likelihoods of size (n_genes x grid_length). + + cr_grid : ndarray + Summands for the Cox-Reid adjustment + (n_genes x grid_length x n_params x n_params). + + Returns + ------- + ndarray + Adjusted negative log likelihood (n_genes x grid_length). + """ + if np.any(np.isnan(cr_grid)): + n_genes, grid_length, n_params, _ = cr_grid.shape + cr_grid = cr_grid.reshape(-1, n_params, n_params) + mask_nan = np.any(np.isnan(cr_grid), axis=(1, 2)) + slogdet = np.zeros(n_genes * grid_length, dtype=cr_grid.dtype) + slogdet[mask_nan] = np.nan + if np.any(~mask_nan): + slogdet[~mask_nan] = np.linalg.slogdet(cr_grid[~mask_nan])[1] + return nll + 0.5 * slogdet.reshape(n_genes, grid_length) + else: + return nll + 0.5 * np.linalg.slogdet(cr_grid)[1] + + +def single_mle_grad( + counts: np.ndarray, design: np.ndarray, mu: np.ndarray, alpha: float +) -> tuple[float, np.ndarray, np.ndarray]: + r"""Estimate the local gradients of a negative binomial GLM wrt dispersions. + + Returns both the gradient of the negative likelihood, and two matrices used to + compute the gradient of the Cox-Reid adjustment. + + + Parameters + ---------- + counts : ndarray + Raw counts for a given gene (n_samples). + + design : ndarray + Design matrix (n_samples x n_params). + + mu : ndarray + Mean estimation for the NB model (n_samples). + + alpha : float + Initial dispersion estimate (1). + + Returns + ------- + grad : ndarray + Gradient of the negative log likelihood of the observations counts following + :math:`NB(\\mu, \\alpha)` (1). + + M1 : ndarray + First summand for the gradient of the CR adjustment (n_params x n_params). + + M2 : ndarray + Second summand for the gradient of the CR adjustment (n_params x n_params). + """ + grad = alpha * dnb_nll(counts, mu, alpha) + W = mu / (1 + mu * alpha) + dW = -(W**2) + M1 = (design.T * W) @ design + M2 = (design.T * dW) @ design + + return grad, M1, M2 + + +def batch_mle_grad( + counts: np.ndarray, design: np.ndarray, mu: np.ndarray, alpha: np.ndarray +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + r"""Estimate the local gradients wrt dispersions on a batch of genes. + + Returns both the gradient of the negative likelihood, and two matrices used to + compute the gradient of the Cox-Reid adjustment. + + + Parameters + ---------- + counts : ndarray + Raw counts for a set of genes (n_samples x n_genes). + + design : ndarray + Design matrix (n_samples x n_params). + + mu : ndarray + Mean estimation for the NB model (n_samples x n_genes). + + alpha : float + Initial dispersion estimate (nn_genes). + + Returns + ------- + grad : ndarray + Gradient of the negative log likelihood of the observations counts following + :math:`NB(\\mu, \\alpha)` (n_genes). + + M1 : ndarray + First summand for the gradient of the CR adjustment + (n_genes x n_params x n_params). 
+ + M2 : ndarray + Second summand for the gradient of the CR adjustment + (n_genes x n_params x n_params). + """ + grad = alpha * vec_nb_nll_grad( + counts, + mu, + alpha, + ) # Need to multiply by alpha to get the gradient wrt log_alpha + + W = mu / (1 + mu * alpha[None, :]) + + dW = -(W**2) + M1 = (design.T[:, :, None] * W).transpose(2, 0, 1) @ design[None, :, :] + M2 = (design.T[:, :, None] * dW).transpose(2, 0, 1) @ design[None, :, :] + + return grad, M1, M2 + + +def batch_mle_update( + log_alpha: np.ndarray, + global_CR_summand_1: np.ndarray, + global_CR_summand_2: np.ndarray, + global_ll_grad: np.ndarray, + lr: float, + alpha_hat: np.ndarray | None = None, + prior_disp_var: float | None = None, + prior_reg: bool = False, +): + """Perform a global dispersions update on a batch of genes. + + Parameters + ---------- + log_alpha : ndarray + Current global log dispersions (n_genes). + + global_CR_summand_1 : ndarray + Global summand 1 for the CR adjustment (n_genes x n_params x n_params). + + global_CR_summand_2 : ndarray + Global summand 2 for the CR adjustment (n_genes x n_params x n_params). + + global_ll_grad : ndarray + Global gradient of the negative log likelihood (n_genes). + + lr : float + Learning rate. + + alpha_hat : ndarray + Reference dispersions (for MAP estimation, n_genes). + + prior_disp_var : float + Prior dispersion variance. + + prior_reg : bool + Whether to use prior regularization for MAP estimation (default: ``False``). + + Returns + ------- + ndarray + Updated global log dispersions (n_genes). + + """ + # Add prior regularization, if required + if prior_reg: + global_ll_grad += (log_alpha - np.log(alpha_hat)) / prior_disp_var + + # Compute CR reg grad (not separable, cannot be computed locally) + global_CR_grad = np.array( + 0.5 + * (np.linalg.inv(global_CR_summand_1) * global_CR_summand_2).sum(1).sum(1) + * np.exp(log_alpha) + ) + + # Update dispersion + global_log_alpha = log_alpha - lr * (global_ll_grad + global_CR_grad) + + return global_log_alpha diff --git a/fedpydeseq2/core/utils/negative_binomial.py b/fedpydeseq2/core/utils/negative_binomial.py new file mode 100644 index 0000000..c0304c3 --- /dev/null +++ b/fedpydeseq2/core/utils/negative_binomial.py @@ -0,0 +1,149 @@ +"""Gradients and loss functions for the negative binomial distribution.""" + +import numpy as np +from scipy.special import gammaln # type: ignore +from scipy.special import polygamma + + +def vec_nb_nll_grad( + counts: np.ndarray, mu: np.ndarray, alpha: np.ndarray +) -> np.ndarray: + r"""Return the gradient of the negative log-likelihood of a negative binomial. + + Vectorized version (wrt genes). + + Parameters + ---------- + counts : ndarray + Observations, n_samples x n_genes. + + mu : ndarray + Mean of the distribution. + + alpha : pd.Series + Dispersion of the distribution, s.t. the variance is + :math:`\\mu + \\alpha_grid * \\mu^2`. + + Returns + ------- + ndarray + Gradient of the negative log likelihood of the observations counts following + :math:`NB(\\mu, \\alpha_grid)`. + """ + alpha_neg1 = 1 / alpha + ll_part = alpha_neg1**2 * ( + polygamma(0, alpha_neg1[None, :]) + - polygamma(0, counts + alpha_neg1[None, :]) + + np.log(1 + mu * alpha[None, :]) + + (counts - mu) / (mu + alpha_neg1[None, :]) + ).sum(0) + + return -ll_part + + +def grid_nb_nll( + counts: np.ndarray, + mu: np.ndarray, + alpha_grid: np.ndarray, + mask_nan: np.ndarray | None = None, +) -> np.ndarray: + r"""Neg log-likelihood of a negative binomial, batched wrt genes on a grid. 
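
The gradient step documented just above in `batch_mle_update` combines the pooled likelihood gradient with a Cox-Reid gradient that can only be formed after aggregation, since it involves the inverse of the summed `M1` matrices. Here is a toy sketch of that update with fabricated values (two genes, two design parameters, no prior regularization; every number is illustrative):

```
import numpy as np

log_alpha = np.log(np.array([0.1, 0.2]))        # current log-dispersions, (n_genes,)
global_ll_grad = np.array([0.5, -0.3])          # likelihood gradients summed over centers
# Cox-Reid summands summed over centers, (n_genes, n_params, n_params)
global_CR_summand_1 = np.stack([np.eye(2) * 4.0, np.eye(2) * 3.0])
global_CR_summand_2 = np.stack([-0.5 * np.eye(2), -0.4 * np.eye(2)])

# Cox-Reid gradient, formed only after aggregation (mirrors batch_mle_update)
global_CR_grad = (
    0.5
    * (np.linalg.inv(global_CR_summand_1) * global_CR_summand_2).sum(1).sum(1)
    * np.exp(log_alpha)
)

lr = 0.1
new_log_alpha = log_alpha - lr * (global_ll_grad + global_CR_grad)
```

The `np.exp(log_alpha)` factor appears because the step is taken with respect to the log-dispersion rather than the dispersion itself, matching the chain-rule factor noted in `batch_mle_grad`.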
+ + Parameters + ---------- + counts : ndarray + Observations, n_samples x n_genes. + + mu : ndarray + Mean estimation for the NB model (n_samples x n_genes). + + alpha_grid : ndarray + Dispersions (n_genes x grid_length). + + mask_nan : ndarray + Mask for the values of the grid where mu should have taken values >> 1. + + Returns + ------- + ndarray + Negative log likelihoods of size (n_genes x grid_length). + """ + n = len(counts) + alpha_neg1 = 1 / alpha_grid + ndim_alpha = alpha_grid.ndim + extra_dims_counts = tuple(range(2, 2 + ndim_alpha - 1)) + expanded_counts = np.expand_dims(counts, axis=extra_dims_counts) + # In order to avoid infinities, we replace all big values in the mu with 1 and + # modify the final quantity with their true value for the inputs were mu should have + # taken values >> 1 + if mask_nan is not None: + mu[mask_nan] = 1.0 + expanded_mu = np.expand_dims(mu, axis=extra_dims_counts) + logbinom = ( + gammaln(expanded_counts + alpha_neg1[None, :]) + - gammaln(expanded_counts + 1) + - gammaln(alpha_neg1[None, :]) + ) + + nll = n * alpha_neg1 * np.log(alpha_grid) + ( + -logbinom + + (expanded_counts + alpha_neg1) * np.log(alpha_neg1 + expanded_mu) + - expanded_counts * np.log(expanded_mu) + ).sum(0) + if mask_nan is not None: + nll[mask_nan.sum(0) > 0] = np.nan + return nll + + +def mu_grid_nb_nll( + counts: np.ndarray, mu_grid: np.ndarray, alpha: np.ndarray +) -> np.ndarray: + r"""Compute the neg log-likelihood of a negative binomial. + + This function is *batched* wrt genes on a mu grid. + + Parameters + ---------- + counts : ndarray + Observations, (n_obs, batch_size). + + mu_grid : ndarray + Means of the distribution :math:`\\mu`, (n_mu, batch_size, n_obs). + + alpha : ndarray + Dispersions of the distribution :math:`\\alpha`, + s.t. the variance is :math:`\\mu + \\alpha \\mu^2`, + of size (batch_size,). + + Returns + ------- + ndarray + Negative log likelihoods of the observations counts + following :math:`NB(\\mu, \\alpha)`, of size (n_mu, batch_size). + + Notes + ----- + [1] https://en.wikipedia.org/wiki/Negative_binomial_distribution + """ + n = len(counts) + alpha_neg1 = 1 / alpha # shape (batch_size,) + logbinom = np.expand_dims( + ( + gammaln(counts.T + alpha_neg1[:, None]) + - gammaln(counts.T + 1) + - gammaln(alpha_neg1[:, None]) + ), + axis=0, + ) # Of size (1, batch_size, n_obs) + first_term = np.expand_dims( + n * alpha_neg1 * np.log(alpha), axis=0 + ) # Of size (1, batch_size) + second_term = np.expand_dims( + counts.T + np.expand_dims(alpha_neg1, axis=1), axis=0 + ) * np.log( + np.expand_dims(alpha_neg1, axis=(0, 2)) + mu_grid + ) # Of size (n_mu, batch_size, n_obs) + third_term = -np.expand_dims(counts.T, axis=0) * np.log( + mu_grid + ) # Of size (n_mu, batch_size, n_obs) + return first_term + (-logbinom + second_term + third_term).sum(axis=2) diff --git a/fedpydeseq2/core/utils/pass_on_results.py b/fedpydeseq2/core/utils/pass_on_results.py new file mode 100644 index 0000000..c98f2ba --- /dev/null +++ b/fedpydeseq2/core/utils/pass_on_results.py @@ -0,0 +1,39 @@ +"""Module to implement the passing of the first shared state. + +# TODO remove after all savings have been factored out, if not needed anymore. +""" + +from substrafl.remote import remote + +from fedpydeseq2.core.utils.logging import log_remote + + +class AggPassOnResults: + """Mixin to pass on the first shared state.""" + + results: dict | None + + @remote + @log_remote + def pass_on_results(self, shared_states: list[dict]) -> dict: + """Pass on the shared state. 
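
As an aside on `mu_grid_nb_nll` defined above: it evaluates the negative binomial negative log-likelihood for several candidate means per gene in a single vectorized call, following the shapes given in its docstring. A small usage sketch with fabricated counts (assuming the package is installed; all values are made up):

```
import numpy as np

from fedpydeseq2.core.utils.negative_binomial import mu_grid_nb_nll

counts = np.array([[3.0, 1.0], [5.0, 2.0], [4.0, 0.0], [6.0, 3.0]])  # (n_obs, batch_size)
alpha = np.array([0.1, 0.2])                                          # (batch_size,)
# Three candidate means per gene, constant across observations: (n_mu, batch_size, n_obs)
mu_grid = np.stack([np.full((2, 4), m) for m in (2.0, 4.0, 6.0)])

nll = mu_grid_nb_nll(counts, mu_grid, alpha)   # (n_mu, batch_size)
best_idx = nll.argmin(axis=0)                  # index of the best candidate mean per gene
```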
+ + This method simply returns the first shared state. + + Parameters + ---------- + shared_states : list + List of shared states. + + Returns + ------- + dict : The first shared state. + + """ + results = shared_states[0] + # This is an ugly way to save the results for the simulation mode. + # In simulation mode, we will look at the results attribute of the class + # to get the results. + # In the real mode, we will download the last shared state. + self.results = results + return results diff --git a/fedpydeseq2/core/utils/pipe_steps.py b/fedpydeseq2/core/utils/pipe_steps.py new file mode 100644 index 0000000..15c3f5d --- /dev/null +++ b/fedpydeseq2/core/utils/pipe_steps.py @@ -0,0 +1,140 @@ +from collections.abc import Callable + +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.nodes.references.shared_state import SharedStateRef + + +def local_step( + local_method: Callable, + train_data_nodes: list[TrainDataNode], + output_local_states: dict[str, LocalStateRef], + round_idx: int, + input_local_states: dict[str, LocalStateRef] | None = None, + input_shared_state: SharedStateRef | None = None, + aggregation_id: str | None = None, + description: str = "", + clean_models: bool = True, + method_params: dict | None = None, +) -> tuple[dict[str, LocalStateRef], list[SharedStateRef], int]: + """Local step of the federated learning strategy. + + Used as a wrapper to execute a local method on the data of each organization. + + Parameters + ---------- + local_method : Callable + Method to be executed on the local data. + train_data_nodes : TrainDataNode + List of TrainDataNode. + output_local_states : dict + Dictionary of local states to be updated. + round_idx : int + Round index. + input_local_states : dict, optional + Dictionary of local states to be used as input. + input_shared_state : SharedStateRef, optional + Shared state to be used as input. + aggregation_id : str, optional + Aggregation node id. + description : str + Description of the algorithm. + clean_models : bool + Whether to clean the models after the computation. + method_params : dict, optional + Optional keyword arguments to be passed to the local method. + + Returns + ------- + output_local_states : dict + Local states containing the results of the local method, + to keep within the training nodes. + output_shared_states : list + Shared states containing the results of the local method, + to be sent to the aggregation node. 
+ round_idx : int + Round index incremented by 1 + """ + output_shared_states = [] + method_params = method_params or {} + + for node in train_data_nodes: + next_local_state, next_shared_state = node.update_states( + local_method( + node.data_sample_keys, + shared_state=input_shared_state, + _algo_name=description, + **method_params, + ), + local_state=( + input_local_states[node.organization_id] if input_local_states else None + ), + round_idx=round_idx, + authorized_ids={node.organization_id}, + aggregation_id=aggregation_id, + clean_models=clean_models, + ) + + output_local_states[node.organization_id] = next_local_state + output_shared_states.append(next_shared_state) + + round_idx += 1 + return output_local_states, output_shared_states, round_idx + + +def aggregation_step( + aggregation_method: Callable, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + input_shared_states: list[SharedStateRef], + round_idx: int, + description: str = "", + clean_models: bool = True, + method_params: dict | None = None, +) -> tuple[SharedStateRef, int]: + """Perform an aggregation step of the federated learning strategy. + + Used as a wrapper to execute an aggregation method on the data of each organization. + + Parameters + ---------- + aggregation_method : Callable + Method to be executed on the shared states. + train_data_nodes : list + List of TrainDataNode. + aggregation_node : AggregationNode + Aggregation node. + input_shared_states : list + List of shared states to be aggregated. + round_idx : int + Round index. + description: str + Description of the algorithm. + clean_models : bool + Whether to clean the models after the computation. + method_params : dict, optional + Optional keyword arguments to be passed to the aggregation method. + + Returns + ------- + SharedStateRef + A shared state containing the results of the aggregation. + round_idx : int + Round index incremented by 1 + """ + method_params = method_params or {} + share_state = aggregation_node.update_states( + aggregation_method( + shared_states=input_shared_states, + _algo_name=description, + **method_params, + ), + round_idx=round_idx, + authorized_ids={ + train_data_node.organization_id for train_data_node in train_data_nodes + }, + clean_models=clean_models, + ) + round_idx += 1 + return share_state, round_idx diff --git a/fedpydeseq2/core/utils/stat_utils.py b/fedpydeseq2/core/utils/stat_utils.py new file mode 100644 index 0000000..cd80b20 --- /dev/null +++ b/fedpydeseq2/core/utils/stat_utils.py @@ -0,0 +1,211 @@ +from typing import Literal + +import numpy as np +from scipy.stats import norm # type: ignore + + +def build_contrast( + design_factors, + design_columns, + continuous_factors=None, + contrast: list[str] | None = None, +) -> list[str]: + """Check the validity of the contrast (if provided). + + If not, build a default + contrast, corresponding to the last column of the design matrix. + A contrast should be a list of three strings, in the following format: + ``['variable_of_interest', 'tested_level', 'reference_level']``. + Names must correspond to the metadata data passed to the FedCenters. + E.g., ``['condition', 'B', 'A']`` will measure the LFC of 'condition B' + compared to 'condition A'. + For continuous variables, the last two strings will be left empty, e.g. + ``['measurement', '', '']. + If None, the last variable from the design matrix + is chosen as the variable of interest, and the reference level is picked + alphabetically. 
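
To make the contrast conventions above concrete, here is a small sketch of the default contrast built from hypothetical design columns, and of the multiplier vector derived from it by `build_contrast_vector` (defined further down in this module). Column names and factor levels are purely illustrative:

```
import pandas as pd

from fedpydeseq2.core.utils.stat_utils import build_contrast, build_contrast_vector

design_columns = ["Intercept", "condition_B_vs_A"]   # hypothetical design matrix columns

# No contrast provided: the last design factor is picked and the default contrast is built.
contrast = build_contrast(["condition"], design_columns)
# contrast -> ["condition", "B", "A"]

# The contrast is then turned into a multiplier vector over the LFC columns.
vector, idx = build_contrast_vector(contrast, pd.Index(design_columns))
# vector -> array([0., 1.]), idx -> 1
```

Testing the reversed pair (e.g. ``["condition", "A", "B"]``) simply flips the sign on the same column, so no refitting is needed.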
+ + Parameters + ---------- + design_factors : list + The design factors. + design_columns : list + The names of the columns of the design matrices in the centers. + continuous_factors : list or None + The continuous factors in the design, if any. (default: ``None``). + contrast : list or None + A list of three strings, in the following format: + ``['variable_of_interest', 'tested_level', 'reference_level']``. + (default: ``None``). + """ + if contrast is not None: # Test contrast if provided + if len(contrast) != 3: + raise ValueError("The contrast should contain three strings.") + if contrast[0] not in design_factors: + raise KeyError( + f"The contrast variable ('{contrast[0]}') should be one " + f"of the design factors." + ) + # TODO: Ideally, we should check that the levels are valid. This might leak + # data from the centers, though. + + else: # Build contrast if None + factor = design_factors[-1] + # Check whether this factor is categorical or continuous. + if continuous_factors is not None and factor in continuous_factors: + # The factor is continuous + contrast = [factor, "", ""] + else: + # The factor is categorical + factor_col = next(col for col in design_columns if col.startswith(factor)) + split_col = factor_col.split("_") + contrast = [split_col[0], split_col[1], split_col[-1]] + + return contrast + + +def build_contrast_vector(contrast, LFC_columns) -> tuple[np.ndarray, int | None]: + """ + Build a vector corresponding to the desired contrast. + + Allows to test any pair of levels without refitting LFCs. + + Parameters + ---------- + contrast : list + A list of three strings, in the following format: + ``['variable_of_interest', 'tested_level', 'reference_level']``. + LFC_columns : list + The names of the columns of the LFC matrices in the centers. + + Returns + ------- + contrast_vector : ndarray + The contrast vector, containing multipliers to apply to the LFCs. + contrast_idx : int or None + The index of the tested contrast in the LFC matrix. + """ + factor = contrast[0] + alternative = contrast[1] + ref = contrast[2] + if ref == alternative == "": + # "factor" is a continuous variable + contrast_level = factor + else: + contrast_level = f"{factor}_{alternative}_vs_{ref}" + + contrast_vector = np.zeros(len(LFC_columns)) + if contrast_level in LFC_columns: + contrast_idx = LFC_columns.get_loc(contrast_level) + contrast_vector[contrast_idx] = 1 + elif f"{factor}_{ref}_vs_{alternative}" in LFC_columns: + # Reference and alternative are inverted + contrast_idx = LFC_columns.get_loc(f"{factor}_{ref}_vs_{alternative}") + contrast_vector[contrast_idx] = -1 + else: + # Need to change reference + # Get any column corresponding to the desired factor and extract old ref + old_ref = next(col for col in LFC_columns if col.startswith(factor)).split( + "_vs_" + )[-1] + new_alternative_idx = LFC_columns.get_loc( + f"{factor}_{alternative}_vs_{old_ref}" + ) + new_ref_idx = LFC_columns.get_loc(f"{factor}_{ref}_vs_{old_ref}") + contrast_vector[new_alternative_idx] = 1 + contrast_vector[new_ref_idx] = -1 + # In that case there is no contrast index + contrast_idx = None + + return contrast_vector, contrast_idx + + +def wald_test( + M: np.ndarray, + lfc: np.ndarray, + ridge_factor: np.ndarray | None, + contrast_vector: np.ndarray, + lfc_null: float, + alt_hypothesis: Literal["greaterAbs", "lessAbs", "greater", "less"] | None, +) -> tuple[float, float, float]: + """Run Wald test for a single gene. + + Computes Wald statistics, standard error and p-values from + dispersion and LFC estimates. 
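
A minimal sketch of calling `wald_test` on fabricated inputs for a two-parameter design (intercept plus one tested coefficient), using the default two-sided alternative; the numbers are placeholders, not real estimates:

```
import numpy as np

from fedpydeseq2.core.utils.stat_utils import wald_test

M = np.array([[50.0, 10.0], [10.0, 40.0]])   # central covariance term, (n_params, n_params)
lfc = np.array([1.2, 0.8])                    # natural-log LFC estimates
contrast_vector = np.array([0.0, 1.0])        # test the second coefficient

p_value, statistic, se = wald_test(
    M=M,
    lfc=lfc,
    ridge_factor=None,                        # falls back to a small diagonal ridge
    contrast_vector=contrast_vector,
    lfc_null=0.0,
    alt_hypothesis=None,                      # default two-sided test
)
```

With `alt_hypothesis=None`, the statistic reduces to `contrast_vector @ (lfc - lfc_null) / wald_se` and the p-value is two-tailed.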
+ + Parameters + ---------- + M : ndarray + Central parameter in the covariance matrix estimator. + + lfc : ndarray + Log-fold change estimate (in natural log scale). + + ridge_factor : ndarray or None + Regularization factors. + + contrast_vector : ndarray + Vector encoding the contrast that is being tested. + + lfc_null : float + The log fold change (in natural log scale) under the null hypothesis. + + alt_hypothesis : str or None + The alternative hypothesis for computing wald p-values. + + Returns + ------- + wald_p_value : float + Estimated p-value. + + wald_statistic : float + Wald statistic. + + wald_se : float + Standard error of the Wald statistic. + """ + # Build covariance matrix estimator + + if ridge_factor is None: + ridge_factor = np.diag(np.repeat(1e-6, M.shape[0])) + H = np.linalg.inv(M + ridge_factor) + Hc = H @ contrast_vector + # Evaluate standard error and Wald statistic + wald_se: float = np.sqrt(Hc.T @ M @ Hc) + + def greater(lfc_null): + stat = contrast_vector @ np.fmax((lfc - lfc_null) / wald_se, 0) + pval = norm.sf(stat) + return stat, pval + + def less(lfc_null): + stat = contrast_vector @ np.fmin((lfc - lfc_null) / wald_se, 0) + pval = norm.sf(np.abs(stat)) + return stat, pval + + def greater_abs(lfc_null): + stat = contrast_vector @ ( + np.sign(lfc) * np.fmax((np.abs(lfc) - lfc_null) / wald_se, 0) + ) + pval = 2 * norm.sf(np.abs(stat)) # Only case where the test is two-tailed + return stat, pval + + def less_abs(lfc_null): + stat_above, pval_above = greater(-abs(lfc_null)) + stat_below, pval_below = less(abs(lfc_null)) + return min(stat_above, stat_below, key=abs), max(pval_above, pval_below) + + wald_statistic: float + wald_p_value: float + if alt_hypothesis: + wald_statistic, wald_p_value = { + "greaterAbs": greater_abs(lfc_null), + "lessAbs": less_abs(lfc_null), + "greater": greater(lfc_null), + "less": less(lfc_null), + }[alt_hypothesis] + else: + wald_statistic = contrast_vector @ (lfc - lfc_null) / wald_se + wald_p_value = 2 * norm.sf(np.abs(wald_statistic)) + + return wald_p_value, wald_statistic, wald_se diff --git a/fedpydeseq2/fedpydeseq2_pipeline.py b/fedpydeseq2/fedpydeseq2_pipeline.py new file mode 100644 index 0000000..055429a --- /dev/null +++ b/fedpydeseq2/fedpydeseq2_pipeline.py @@ -0,0 +1,149 @@ +from pathlib import Path + +import yaml # type: ignore +from substra.sdk.schemas import BackendType + +from fedpydeseq2.core.deseq2_strategy import DESeq2Strategy +from fedpydeseq2.substra_utils.federated_experiment import run_federated_experiment + + +def run_fedpydeseq2_experiment( + n_centers: int = 2, + backend: BackendType = "subprocess", + register_data: bool = False, + simulate: bool = True, + asset_directory: Path | None = None, + centers_root_directory: Path | None = None, + compute_plan_name: str = "FedPyDESeq2Experiment", + dataset_name: str = "MyDatasetName", + remote_timeout: int = 86400, # 24 hours + clean_models: bool = True, + save_filepath: str | Path | None = None, + credentials_path: str | Path | None = None, + dataset_datasamples_keys_path: str | Path | None = None, + cp_id_path: str | Path | None = None, + parameter_file: str | Path | None = None, + fedpydeseq2_wheel_path: str | Path | None = None, + **kwargs, +) -> dict: + """Run a federated experiment using the DESeq2 strategy. + + Parameters + ---------- + n_centers : int + Number of centers to use in the federated experiment. + + backend : BackendType + Backend to use for the experiment. Should be one of "subprocess", "docker" + or "remote". 
+ + register_data : bool + Whether to register the data on the substra platform. Can be True only + when using the remote backend. + + simulate : bool + Whether to simulate the experiment. If True, the experiment will be simulated + and no data will be sent to the centers. This can be True only in subprocess + backend. + + asset_directory : Path + Path to the directory containing the assets (opener.py and description.md). + + centers_root_directory : Path or None + Path to the directory containing the centers data. Can be None only in remote + mode when register_data is False. + The centers data should be organized as follows: + ``` + + ├── center_0 + │ ├── counts_data.csv + │ └── metadata.csv + ├── center_1 + │ ├── counts_data.csv + │ └── metadata.csv + └── + + ``` + where the metadata.csv file is indexed by sample barcodes and contains + all columns needed to build the design matrix, and the counts_data.csv file + represents a dataframe with gene names as columns and sample barcodes as rows, + in the "barcode" column. + + compute_plan_name : str + Name of the compute plan to use for the experiment. + + dataset_name : str + Name of the dataset to fill in the Dataset schema. + + remote_timeout : int + Timeout in seconds for the remote backend. + + clean_models : bool + Whether to clean the models after the experiment. + + save_filepath : str or Path + Path to save the results of the experiment. + + credentials_path : str or Path + Path to the file containing the credentials to use for the remote backend. + + dataset_datasamples_keys_path : str or Path + Path to the file containing the datasamples keys of the dataset. + Only used for the remote backend. + Is filled in if register_data is True, and read if register_data is False. + + cp_id_path : str or Path or None + Path to the file containing the compute plan id. + This file is a yaml file with the following structure: + ``` + algo_org_name: str + credentials_path: str + compute_plan_key: str + ``` + + parameter_file : str or Path or None + If not None, yaml file containing the parameters to pass to the DESeq2Strategy. + If None, the default parameters are used. + + fedpydeseq2_wheel_path : str or Path or None + Path to the wheel file of the fedpydeseq2 package. If provided and the backend + is remote, this wheel will be added to the dependencies. + + **kwargs + Arguments to pass to the DESeq2Strategy. They will overwrite those specified + in the parameter_file if the file is not None. + + Returns + ------- + dict + Result of the strategy, which are assumed to be contained in the + results attribute of the last round of the aggregation node. 
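
For orientation, a minimal sketch of a local, simulated run with placeholder paths (the data layout is the one described above; the exact combination of `backend`, `register_data` and `simulate` flags may need adjusting for your setup):

```
from pathlib import Path

from fedpydeseq2.fedpydeseq2_pipeline import run_fedpydeseq2_experiment

# Placeholder paths: one sub-folder per center (counts_data.csv + metadata.csv)
# and an assets folder holding opener.py and description.md.
results = run_fedpydeseq2_experiment(
    n_centers=2,
    backend="subprocess",
    register_data=True,
    simulate=True,
    centers_root_directory=Path("/path/to/centers_data"),
    asset_directory=Path("/path/to/assets"),
    compute_plan_name="FedPyDESeq2Experiment",
    save_filepath=None,
)
```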
+ + + """ + if parameter_file is not None: + with open(parameter_file, "rb") as file: + parameters = yaml.load(file, Loader=yaml.FullLoader) + else: + parameters = {} + parameters.update(kwargs) + strategy = DESeq2Strategy(**parameters) + + return run_federated_experiment( + strategy=strategy, + n_centers=n_centers, + backend=backend, + register_data=register_data, + simulate=simulate, + centers_root_directory=centers_root_directory, + assets_directory=asset_directory, + compute_plan_name=compute_plan_name, + dataset_name=dataset_name, + remote_timeout=remote_timeout, + clean_models=clean_models, + save_filepath=save_filepath, + credentials_path=credentials_path, + dataset_datasamples_keys_path=dataset_datasamples_keys_path, + cp_id_path=cp_id_path, + fedpydeseq2_wheel_path=fedpydeseq2_wheel_path, + ) diff --git a/fedpydeseq2/substra_utils/__init__.py b/fedpydeseq2/substra_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fedpydeseq2/substra_utils/credentials/credentials-template.yaml b/fedpydeseq2/substra_utils/credentials/credentials-template.yaml new file mode 100644 index 0000000..6932887 --- /dev/null +++ b/fedpydeseq2/substra_utils/credentials/credentials-template.yaml @@ -0,0 +1,9 @@ +org1: + url: "" + token: "******" +org2: + url: "" + token: "*******" +org3: + url: "" + token: "******" diff --git a/fedpydeseq2/substra_utils/credentials/dataset-datasamples-keys-template.yaml b/fedpydeseq2/substra_utils/credentials/dataset-datasamples-keys-template.yaml new file mode 100644 index 0000000..926e50b --- /dev/null +++ b/fedpydeseq2/substra_utils/credentials/dataset-datasamples-keys-template.yaml @@ -0,0 +1,6 @@ +Org2MSP: + datasample_key: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX + dataset_key: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX +Org3MSP: + datasample_key: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX + dataset_key: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX diff --git a/fedpydeseq2/substra_utils/federated_experiment.py b/fedpydeseq2/substra_utils/federated_experiment.py new file mode 100644 index 0000000..e3c871f --- /dev/null +++ b/fedpydeseq2/substra_utils/federated_experiment.py @@ -0,0 +1,490 @@ +import pickle as pkl +import tempfile +import time +from pathlib import Path + +import numpy as np +import yaml # type: ignore +from loguru import logger +from substra.sdk.models import ComputePlanStatus +from substra.sdk.schemas import BackendType +from substra.sdk.schemas import DataSampleSpec +from substra.sdk.schemas import DatasetSpec +from substra.sdk.schemas import Permissions +from substrafl import ComputePlanBuilder +from substrafl.experiment import execute_experiment +from substrafl.experiment import simulate_experiment +from substrafl.model_loading import download_aggregate_shared_state +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode + +from fedpydeseq2.substra_utils.utils import check_datasample_folder +from fedpydeseq2.substra_utils.utils import get_client +from fedpydeseq2.substra_utils.utils import get_dependencies + + +def run_federated_experiment( + strategy: ComputePlanBuilder, + n_centers: int = 2, + backend: BackendType = "subprocess", + register_data: bool = False, + simulate: bool = True, + centers_root_directory: Path | None = None, + assets_directory: Path | None = None, + compute_plan_name: str = "FedPyDESeq2Experiment", + dataset_name: str = "TCGA", + remote_timeout: int = 86400, # 24 hours + clean_models: bool = True, + save_filepath: str | Path | None = None, + credentials_path: str | Path | None = None, + 
dataset_datasamples_keys_path: str | Path | None = None, + cp_id_path: str | Path | None = None, + fedpydeseq2_wheel_path: str | Path | None = None, +) -> dict: + """Run a federated experiment with the given strategy. + + In remote mode, if the data is already registered, + the assets_directory and centers_root_directory + are not used (register_data=False). + + Otherwise, the assets_directory and centers_root_directory must be + provided. The assets_directory is expected to contain the opener.py + and description.md files, used to create the dataset for all centers. + The centers_root_directory is expected to contain a subdirectory for each center, + in the following form: + + ``` + + ├── center_0 + ├── center_1 + + ``` + + These directories contain the necessary data for each center and are passed + to the DataSampleSpec object to register the data to substra. + + Parameters + ---------- + strategy : ComputePlanBuilder + The strategy to use for the federated experiment. + + n_centers : int + The number of centers to use in the experiment. + + backend : BackendType + The backend to use for the experiment. Can be one of "subprocess", + "docker", or "remote". + + register_data : bool + Whether to register the data. If True, the assets_directory and + centers_root_directory must be provided. + Can be False only in "remote" mode. + + simulate : bool + Whether to simulate the experiment. If True, the experiment must be run + in subprocess mode. + + centers_root_directory : Optional[Path] + The path to the root directory containing the data for each center. + This is only used if register_data is True. + + assets_directory : Optional[Path] + The path to the assets directory. It must contain the opener.py file + and the description.md file. This is only used if register_data is True. + + compute_plan_name : str + The name of the compute plan. + + dataset_name : str + The name of the dataset to use, to be passed to the DatasetSpec object and used + to create the path of the yaml file storing the dataset and datasample keys. + + remote_timeout : int + The timeout for the remote backend in seconds. + + clean_models : bool + Whether to clean the models after the experiment. + + save_filepath : Optional[Union[str, Path]] + The path to save the results. If None, the results are not saved. + + credentials_path : Optional[Union[str, Path]] + The path to the credentials file. By default, will be set to + Path(__file__).parent / "credentials/credentials.yaml" + This file is used only in remote mode, and is expected to be a dictionary with + the following structure: + ``` + org1: + url: "****" + token: "****" + org2: + url: "****" + token: "****" + ... + ``` + The first organization is assumed to be the algorithm provider. + The other organizations are the data providers. + + + dataset_datasamples_keys_path : Optional[Union[str, Path]] + The path to the file containing the dataset and datasamples keys. + If None, and if backend is "remote", will be set to + Path(__file__).parent / "credentials/-datasamples-keys.yaml" + This file is used only in remote mode, and is expected to be a dictionary with + the following structure: + ``` + org_id: + dataset_key: "****" + datasample_key: "****" + ... + ``` + Where all data provider org ids are present, and there is only one + datasample key per org id. + This file is generated if register_data is True and backend is "remote". + This file is loaded if register_data is False and backend is "remote". 
+ + cp_id_path : str or Path or None + The path to a file where we save the necessary information to + retrieve the compute plan. This parameter + is only used in remote mode. + If None, this information is not saved. + If a path is provided, the information is saved in a yaml file with the + following structure: + ``` + compute_plan_key: "****" + credentials_path: "****" + algo_org_name: "****" + ``` + + fedpydeseq2_wheel_path : Optional[Union[str, Path]] + The path to the wheel file of the fedpydeseq2 package. If provided and the + backend is remote, this wheel will be added to the dependencies. + + Returns + ------- + dict + Result of the strategy, which are assumed to be contained in the + results attribute of the last round of the aggregation node. + """ + # %% + # Setup + # ***** + # In the following code cell, we define the different + # organizations needed for our FL experiment. + # Every computation will run in `subprocess` mode, + # where everything runs locally in Python + # subprocesses. + # Others backend_types are: + # "docker" mode where computations run locally in docker + # containers + # "remote" mode where computations run remotely (you need to + # have a deployed platform for that) + logger.info("Setting up organizations...") + n_clients = n_centers + 1 + if backend == "remote": + clients_ = [ + get_client( + backend_type=backend, + org_name=f"org{i}", + credentials_path=credentials_path, + ) + for i in range(1, n_clients + 1) + ] + else: + clients_ = [get_client(backend_type=backend) for _ in range(n_clients)] + + clients = { + client.organization_info().organization_id: client for client in clients_ + } + + # Store organization IDs + all_orgs_id = list(clients.keys()) + algo_org_id = all_orgs_id[0] # Algo provider is defined as the first organization. + data_providers_ids = all_orgs_id[ + 1: + ] # Data providers orgs are the remaining organizations. + + # %% + # Dataset registration + # ==================== + # + # A :ref:`documentation/concepts:Dataset` is composed of an **opener**, + # which is a Python script that can load + # the data from the files in memory and a description markdown file. + # The :ref:`documentation/concepts:Dataset` object itself does not contain + # the data. The proper asset that contains the + # data is the **datasample asset**. + # + # A **datasample** contains a local path to the data. A datasample can be + # linked to a dataset in order to add data to a + # dataset. + # + # Data privacy is a key concept for Federated Learning experiments. + # That is why we set + # :ref:`documentation/concepts:Permissions` for :ref:`documentation/concepts:Assets` + # to determine how each organization can access a specific asset. + # + # Note that metadata such as the assets' creation date and the asset owner are + # visible to all the organizations of a + # network. + + # Define the path to the asset. + if register_data: + logger.info("Registering the datasets...") + else: + logger.info("Using pre-registered datasets...") + + dataset_keys = {} + train_datasample_keys = {} + + if dataset_datasamples_keys_path is None: + dataset_datasamples_keys_path = ( + Path(__file__).parent / f"credentials/{dataset_name}-datasamples-keys.yaml" + ) + else: + dataset_datasamples_keys_path = Path(dataset_datasamples_keys_path) + + if not register_data: + # Check that we are in remote mode + assert backend == "remote", ( + "register_data must be True if backend is not remote," + "as the datasets can be saved and reused only in remote mode." 
+ "If register_data is False, the dataset_datasamples_keys_path " + "provides the necessary information to load the data which is " + "already present on each remote organization." + ) + # Load the dataset and datasample keys from the file + with open(dataset_datasamples_keys_path) as file: + dataset_datasamples_keys = yaml.load(file, Loader=yaml.FullLoader) + for org_id in data_providers_ids: + dataset_keys[org_id] = dataset_datasamples_keys[org_id]["dataset_key"] + train_datasample_keys[org_id] = dataset_datasamples_keys[org_id][ + "datasample_key" + ] + logger.info("Datasets fetched.") + else: + for i, org_id in enumerate(data_providers_ids): + client = clients[org_id] + + # In this case, check that the assets_directory is provided + assert ( + assets_directory is not None + ), "assets_directory must be provided if register_data is True" + # In this case, check that the centers_root_directory is provided + assert centers_root_directory is not None, ( + "centers_root_directory must be provided if" "register_data is True" + ) + + permissions_dataset = Permissions(public=True, authorized_ids=all_orgs_id) + + # DatasetSpec is the specification of a dataset. It makes sure every field + # is well-defined, and that our dataset is ready to be registered. + # The real dataset object is created in the add_dataset method. + dataset = DatasetSpec( + name=dataset_name, + data_opener=assets_directory / "opener.py", + description=assets_directory / "description.md", + permissions=permissions_dataset, + logs_permission=permissions_dataset, + ) + logger.info( + f"Adding dataset to client " + f"{str(client.organization_info().organization_id)}" + ) + dataset_keys[org_id] = client.add_dataset(dataset) + logger.info(f"Dataset added. Key: {dataset_keys[org_id]}") + assert dataset_keys[org_id], "Missing dataset key" + data_sample = DataSampleSpec( + data_manager_keys=[dataset_keys[org_id]], + path=centers_root_directory / f"center_{i}", + ) + if backend == "remote": + check_datasample_folder(data_sample.path) + train_datasample_keys[org_id] = client.add_data_sample(data_sample) + + # Create the dataset and datasample keys file if the backend is remote + if backend == "remote": + dataset_datasamples_dico = { + org_id: { + "dataset_key": dataset_keys[org_id], + "datasample_key": train_datasample_keys[org_id], + } + for org_id in data_providers_ids + } + with open(dataset_datasamples_keys_path, "w") as file: + yaml.dump(dataset_datasamples_dico, file) + logger.info("Datasets registered.") + + logger.info(f"Dataset keys: {dataset_keys}") + + # %% + # Where to train where to aggregate + # ================================= + # + # We specify on which data we want to train our model, using + # the :ref:`substrafl_doc/api/nodes:TrainDataNode` object. + # Here we train on the two datasets that we have registered earlier. + # + # The :ref:`substrafl_doc/api/nodes:AggregationNode` specifies the + # organization on which the aggregation operation + # will be computed. + + aggregation_node = AggregationNode(algo_org_id) + + train_data_nodes = [] + + for org_id in data_providers_ids: + # Create the Train Data Node (or training task) and save it in a list + train_data_node = TrainDataNode( + organization_id=org_id, + data_manager_key=dataset_keys[org_id], + data_sample_keys=[train_datasample_keys[org_id]], + ) + train_data_nodes.append(train_data_node) + + # %% + # Running the experiment + # ********************** + # + # We now have all the necessary objects to launch our experiment. 
+ # Please see a summary below of all the objects we created so far: + # + # - A :ref:`documentation/references/sdk:Client` to add or retrieve + # the assets of our experiment, using their keys to + # identify them. + # - A `Federated Strategy `_, + # implementing the pipeline that will be run. + # - `Train data nodes `_ to + # indicate on which data to train. + # - An :ref:`substrafl_doc/api/nodes:AggregationNode`, to specify the + # organization on which the aggregation operation + # will be computed. + # - An **experiment folder** to save a summary of the operation made. + # - The :ref:`substrafl_doc/api/dependency:Dependency` to define the + # libraries on which the experiment needs to run. + + # The Dependency object is instantiated in order to install the right + # libraries in the Python environment of each organization. + + algo_deps = get_dependencies( + backend_type=backend, fedpydeseq2_wheel_path=fedpydeseq2_wheel_path + ) + + exp_path = tempfile.mkdtemp() + + if simulate: + if backend != "subprocess": + raise ValueError("Simulated experiment can only be run in subprocess mode.") + _, intermediate_train_state, intermediate_state_agg = simulate_experiment( + client=clients[algo_org_id], + strategy=strategy, + train_data_nodes=train_data_nodes, + evaluation_strategy=None, + aggregation_node=aggregation_node, + clean_models=clean_models, + num_rounds=strategy.num_round, + experiment_folder=exp_path, + ) + + # Gather results from the aggregation node + + agg_client_id_mask = [ + w == clients[algo_org_id].organization_info().organization_id + for w in intermediate_state_agg.worker + ] + + agg_round_id_mask = [ + r == max(intermediate_state_agg.round_idx) + for r in intermediate_state_agg.round_idx + ] + + agg_state_idx = np.where( + [ + r and w + for r, w in zip(agg_round_id_mask, agg_client_id_mask, strict=False) + ] + )[0][0] + + fl_results = intermediate_state_agg.state[agg_state_idx].results + else: + algo_client = clients[algo_org_id] + + compute_plan = execute_experiment( + client=algo_client, + strategy=strategy, + train_data_nodes=train_data_nodes, + evaluation_strategy=None, + aggregation_node=aggregation_node, + num_rounds=strategy.num_round, + experiment_folder=exp_path, + dependencies=algo_deps, + clean_models=clean_models, + name=compute_plan_name, + ) + + compute_plan_key = compute_plan.key + + # Extract the results. The method used here downloads the results from the + # training nodes, as we cannot download + # results from the aggregation node. Note that it implies an extra step + # for the aggregation node to share the result with the training nodes. 
+ + if cp_id_path is not None: + cp_id_path = Path(cp_id_path) + cp_id_path.parent.mkdir(parents=True, exist_ok=True) + with cp_id_path.open("w") as f: + yaml.dump( + { + "compute_plan_key": compute_plan_key, + "credentials_path": credentials_path, + "algo_org_name": "org1", + }, + f, + ) + + if backend == "remote": + sleep_time = 60 + t1 = time.time() + finished = False + while (time.time() - t1) < remote_timeout: + status = algo_client.get_compute_plan(compute_plan_key).status + logger.info( + f"Compute plan status is {status}, after {(time.time() - t1):.2f}s" + ) + if status == ComputePlanStatus.done: + logger.info("Compute plan has finished successfully") + finished = True + break + elif ( + status == ComputePlanStatus.failed + or status == ComputePlanStatus.canceled + ): + raise ValueError("Compute plan has failed") + elif ( + status == ComputePlanStatus.doing + or status == ComputePlanStatus.created + ): + pass + else: + logger.info( + f"Compute plan status is {status}, this shouldn't " + f"happen, sleeping {sleep_time} and retrying " + f"until timeout {remote_timeout}" + ) + time.sleep(sleep_time) + if not finished: + raise ValueError( + f"Compute plan did not finish after {remote_timeout} seconds" + ) + + fl_results = download_aggregate_shared_state( + client=algo_client, + compute_plan_key=compute_plan_key, + round_idx=None, + ) + if save_filepath is not None: + pkl_save_filepath = Path(save_filepath) / "fl_result.pkl" + with pkl_save_filepath.open("wb") as f: + pkl.dump(fl_results, f) + + return fl_results diff --git a/fedpydeseq2/substra_utils/utils.py b/fedpydeseq2/substra_utils/utils.py new file mode 100644 index 0000000..663a918 --- /dev/null +++ b/fedpydeseq2/substra_utils/utils.py @@ -0,0 +1,186 @@ +from pathlib import Path + +import yaml # type: ignore +from loguru import logger +from substra import BackendType +from substra import Client +from substrafl.dependency import Dependency + + +def get_client( + backend_type: BackendType, + org_name: str | None = None, + credentials_path: str | Path | None = None, +) -> Client: + """ + Return a substra client for a given organization. + + Parameters + ---------- + backend_type : str + Name of the backend to connect to. Should be "subprocess", "docker" or "remote" + org_name : str, optional. + Name of the organization to connect to. Required when using remote backend. + credentials_path : str or Path + Path to the credentials file. By default, will be set to + Path(__file__).parent / "credentials/credentials.yaml" + + """ + if backend_type not in ("subprocess", "docker", "remote"): + raise ValueError( + f"Backend type {backend_type} not supported. Should be one of 'subprocess'," + f" 'docker' or 'remote'." + ) + if backend_type == "remote": + assert ( + org_name is not None + ), "Organization name must be provided when using remote backend." + if credentials_path is not None: + credential_path = Path(credentials_path) + else: + credential_path = Path(__file__).parent / "credentials/credentials.yaml" + + with open(credential_path) as file: + conf = yaml.load(file, Loader=yaml.FullLoader) + if org_name not in conf.keys(): + raise ValueError(f"Organization {org_name} not found in credentials file.") + url = conf[org_name]["url"] + token = conf[org_name]["token"] + + logger.info( + f"Connecting to {org_name} " + f"at {url} using credentials " + f"from {credential_path}." 
+ ) + return Client(url=url, token=token, backend_type="remote") + else: + return Client(backend_type=backend_type) + + +def cancel_compute_plan(cp_id_path: str | Path): + """ + Cancel a compute plan. + + We assume that we are in the remote setting. + + Parameters + ---------- + cp_id_path : str or Path + Path to the file containing the compute plan id. + This file is a yaml file with the following structure: + ``` + algo_org_name: str + credentials_path: str + compute_plan_key: str + ``` + """ + try: + with open(cp_id_path) as file: + conf = yaml.load(file, Loader=yaml.FullLoader) + + algo_org_name = conf["algo_org_name"] + credentials_path = conf["credentials_path"] + client = get_client( + backend_type="remote", + org_name=algo_org_name, + credentials_path=credentials_path, + ) + compute_plan_key = conf["compute_plan_key"] + client.cancel_compute_plan(compute_plan_key) + except Exception as e: # noqa : BLE001 + print( + f"An error occured while cancelling the compute plan: {e}." + f"Maybe it was already cancelled, or never launched ?" + ) + + +def get_n_centers_from_datasamples_file(datasamples_file: str | Path) -> int: + """ + Return the number of centers from a datasamples file. + + Parameters + ---------- + datasamples_file: str or Path + Path to the yaml file containing the datasamples keys of the dataset. + + Returns + ------- + int + Number of centers in the datasamples file. + + """ + with open(datasamples_file) as file: + dataset_datasamples_keys = yaml.load(file, Loader=yaml.FullLoader) + return len(dataset_datasamples_keys) + + +def get_dependencies( + backend_type: BackendType, + fedpydeseq2_wheel_path: str | Path | None = None, +) -> Dependency: + """ + Return a substra Dependency in regard to the backend_type. + + Parameters + ---------- + backend_type : BackendType + Name of the backend to connect to. Should be "subprocess", "docker" or "remote" + fedpydeseq2_wheel_path : str | Path | None, optional + Path to the wheel file of the fedpydeseq2 package. If provided and the backend + is remote or docker, this wheel will be used instead of downloading it. + + Raises + ------ + FileNotFoundError + If the wheel file cannot be downloaded or found. + """ + # in subprocess the dependency are not used, no need to build the wheel. + if backend_type == BackendType.LOCAL_SUBPROCESS: + return Dependency() + + if fedpydeseq2_wheel_path: + wheel_path = Path(fedpydeseq2_wheel_path) + if not wheel_path.exists(): + raise FileNotFoundError(f"Provided wheel file not found: {wheel_path}") + logger.info(f"Using provided wheel path: {wheel_path}") + return Dependency(local_installable_dependencies=[wheel_path]) + else: + raise FileNotFoundError( + "You must provide a wheel path when using a remote backend." + ) + + +def check_datasample_folder(datasample_folder: Path) -> None: + """ + Sanity check for the datasample folder. + + Check if the datasample folder contains only two csv files: counts_data.csv + and metadata.csv and nothing else. + + Parameters + ---------- + datasample_folder : Path + Path to the datasample folder. + + Raises + ------ + ValueError + If the datasample folder does not contain exactly two files named + 'counts_data.csv' and 'metadata.csv'. + + """ + if not datasample_folder.is_dir(): + raise ValueError(f"{datasample_folder} is not a directory.") + files = list(datasample_folder.iterdir()) + if len(files) != 2: + raise ValueError( + "Datasample folder should contain exactly two files, " + f"found {len(files)}: {files}." 
+ ) + if {file.name for file in files} != {"counts_data.csv", "metadata.csv"}: + raise ValueError( + "Datasample folder should contain two csv files named 'counts_data.csv'" + " and 'metadata.csv'." + ) + + return diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 0000000..82439bd --- /dev/null +++ b/poetry.lock @@ -0,0 +1,2164 @@ +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. + +[[package]] +name = "anndata" +version = "0.10.8" +description = "Annotated data." +optional = false +python-versions = ">=3.9" +files = [ + {file = "anndata-0.10.8-py3-none-any.whl", hash = "sha256:1b24934dc2674eaf3072cb7010e187aa2b2f4f0e4cf0a32ffeab5ffebe3b1415"}, + {file = "anndata-0.10.8.tar.gz", hash = "sha256:b728a33225eeaaefddf6bed546d935c0f06881c9166621b24de3b492b2f406bb"}, +] + +[package.dependencies] +array-api-compat = ">1.4,<1.5 || >1.5" +exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} +h5py = ">=3.1" +natsort = "*" +numpy = ">=1.23" +packaging = ">=20.0" +pandas = ">=1.4,<2.1.0rc0 || >2.1.0rc0,<2.1.2 || >2.1.2" +scipy = ">1.8" + +[package.extras] +dev = ["pytest-xdist", "setuptools-scm"] +doc = ["awkward (>=2.0.7)", "ipython", "myst-parser", "nbsphinx", "readthedocs-sphinx-search", "scanpydoc[theme,typehints] (>=0.13.4)", "sphinx (>=4.4)", "sphinx-autodoc-typehints (>=1.11.0)", "sphinx-book-theme (>=1.1.0)", "sphinx-copybutton", "sphinx-design (>=0.5.0)", "sphinx-issues", "sphinxext-opengraph", "zarr"] +gpu = ["cupy"] +test = ["awkward (>=2.3)", "boltons", "dask[array,distributed] (>=2022.09.2)", "httpx", "joblib", "loompy (>=3.0.5)", "matplotlib", "openpyxl", "pyarrow", "pytest (>=8.2)", "pytest-cov (>=2.10)", "pytest-memray", "pytest-mock", "scanpy", "scikit-learn", "zarr (<3.0.0a0)"] + +[[package]] +name = "annotated-types" +version = "0.7.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, +] + +[[package]] +name = "array-api-compat" +version = "1.9.1" +description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard" +optional = false +python-versions = ">=3.9" +files = [ + {file = "array_api_compat-1.9.1-py3-none-any.whl", hash = "sha256:41a2703a662832d21619359ddddc5c0449876871f6c01e108c335f2a9432df94"}, + {file = "array_api_compat-1.9.1.tar.gz", hash = "sha256:17bab828c93c79a5bb8b867145b71fcb889686607c5672b060aef437e0359ea8"}, +] + +[package.extras] +cupy = ["cupy"] +dask = ["dask"] +jax = ["jax"] +numpy = ["numpy"] +pytorch = ["pytorch"] +sparse = ["sparse (>=0.15.1)"] + +[[package]] +name = "black" +version = "24.10.0" +description = "The uncompromising code formatter." 
+optional = false +python-versions = ">=3.9" +files = [ + {file = "black-24.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6668650ea4b685440857138e5fe40cde4d652633b1bdffc62933d0db4ed9812"}, + {file = "black-24.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1c536fcf674217e87b8cc3657b81809d3c085d7bf3ef262ead700da345bfa6ea"}, + {file = "black-24.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:649fff99a20bd06c6f727d2a27f401331dc0cc861fb69cde910fe95b01b5928f"}, + {file = "black-24.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:fe4d6476887de70546212c99ac9bd803d90b42fc4767f058a0baa895013fbb3e"}, + {file = "black-24.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5a2221696a8224e335c28816a9d331a6c2ae15a2ee34ec857dcf3e45dbfa99ad"}, + {file = "black-24.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f9da3333530dbcecc1be13e69c250ed8dfa67f43c4005fb537bb426e19200d50"}, + {file = "black-24.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4007b1393d902b48b36958a216c20c4482f601569d19ed1df294a496eb366392"}, + {file = "black-24.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:394d4ddc64782e51153eadcaaca95144ac4c35e27ef9b0a42e121ae7e57a9175"}, + {file = "black-24.10.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b5e39e0fae001df40f95bd8cc36b9165c5e2ea88900167bddf258bacef9bbdc3"}, + {file = "black-24.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d37d422772111794b26757c5b55a3eade028aa3fde43121ab7b673d050949d65"}, + {file = "black-24.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14b3502784f09ce2443830e3133dacf2c0110d45191ed470ecb04d0f5f6fcb0f"}, + {file = "black-24.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:30d2c30dc5139211dda799758559d1b049f7f14c580c409d6ad925b74a4208a8"}, + {file = "black-24.10.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1cbacacb19e922a1d75ef2b6ccaefcd6e93a2c05ede32f06a21386a04cedb981"}, + {file = "black-24.10.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1f93102e0c5bb3907451063e08b9876dbeac810e7da5a8bfb7aeb5a9ef89066b"}, + {file = "black-24.10.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ddacb691cdcdf77b96f549cf9591701d8db36b2f19519373d60d31746068dbf2"}, + {file = "black-24.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:680359d932801c76d2e9c9068d05c6b107f2584b2a5b88831c83962eb9984c1b"}, + {file = "black-24.10.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:17374989640fbca88b6a448129cd1745c5eb8d9547b464f281b251dd00155ccd"}, + {file = "black-24.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:63f626344343083322233f175aaf372d326de8436f5928c042639a4afbbf1d3f"}, + {file = "black-24.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ccfa1d0cb6200857f1923b602f978386a3a2758a65b52e0950299ea014be6800"}, + {file = "black-24.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:2cd9c95431d94adc56600710f8813ee27eea544dd118d45896bb734e9d7a0dc7"}, + {file = "black-24.10.0-py3-none-any.whl", hash = "sha256:3bb2b7a1f7b685f85b11fed1ef10f8a9148bceb49853e47a294a3dd963c1dd7d"}, + {file = "black-24.10.0.tar.gz", hash = "sha256:846ea64c97afe3bc677b761787993be4991810ecc7a4a937816dd6bddedc4875"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", 
markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.10)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + +[[package]] +name = "build" +version = "1.2.2.post1" +description = "A simple, correct Python build frontend" +optional = false +python-versions = ">=3.8" +files = [ + {file = "build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5"}, + {file = "build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "os_name == \"nt\""} +importlib-metadata = {version = ">=4.6", markers = "python_full_version < \"3.10.2\""} +packaging = ">=19.1" +pyproject_hooks = "*" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} + +[package.extras] +docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"] +test = ["build[uv,virtualenv]", "filelock (>=3)", "pytest (>=6.2.4)", "pytest-cov (>=2.12)", "pytest-mock (>=2)", "pytest-rerunfailures (>=9.1)", "pytest-xdist (>=1.34)", "setuptools (>=42.0.0)", "setuptools (>=56.0.0)", "setuptools (>=56.0.0)", "setuptools (>=67.8.0)", "wheel (>=0.36.0)"] +typing = ["build[uv]", "importlib-metadata (>=5.1)", "mypy (>=1.9.0,<1.10.0)", "tomli", "typing-extensions (>=3.7.4.3)"] +uv = ["uv (>=0.1.18)"] +virtualenv = ["virtualenv (>=20.0.35)"] + +[[package]] +name = "certifi" +version = "2024.8.30" +description = "Python package for providing Mozilla's CA Bundle." +optional = false +python-versions = ">=3.6" +files = [ + {file = "certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8"}, + {file = "certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9"}, +] + +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.0" +description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
+optional = false +python-versions = ">=3.7.0" +files = [ + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4f9fc98dad6c2eaa32fc3af1417d95b5e3d08aff968df0cd320066def971f9a6"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0de7b687289d3c1b3e8660d0741874abe7888100efe14bd0f9fd7141bcbda92b"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5ed2e36c3e9b4f21dd9422f6893dec0abf2cca553af509b10cd630f878d3eb99"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d3ff7fc90b98c637bda91c89d51264a3dcf210cade3a2c6f838c7268d7a4ca"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1110e22af8ca26b90bd6364fe4c763329b0ebf1ee213ba32b68c73de5752323d"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:86f4e8cca779080f66ff4f191a685ced73d2f72d50216f7112185dc02b90b9b7"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f683ddc7eedd742e2889d2bfb96d69573fde1d92fcb811979cdb7165bb9c7d3"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27623ba66c183eca01bf9ff833875b459cad267aeeb044477fedac35e19ba907"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f606a1881d2663630ea5b8ce2efe2111740df4b687bd78b34a8131baa007f79b"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0b309d1747110feb25d7ed6b01afdec269c647d382c857ef4663bbe6ad95a912"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:136815f06a3ae311fae551c3df1f998a1ebd01ddd424aa5603a4336997629e95"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:14215b71a762336254351b00ec720a8e85cada43b987da5a042e4ce3e82bd68e"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:79983512b108e4a164b9c8d34de3992f76d48cadc9554c9e60b43f308988aabe"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-win32.whl", hash = "sha256:c94057af19bc953643a33581844649a7fdab902624d2eb739738a30e2b3e60fc"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:55f56e2ebd4e3bc50442fbc0888c9d8c94e4e06a933804e2af3e89e2f9c1c749"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0d99dd8ff461990f12d6e42c7347fd9ab2532fb70e9621ba520f9e8637161d7c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c57516e58fd17d03ebe67e181a4e4e2ccab1168f8c2976c6a334d4f819fe5944"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dba5d19c4dfab08e58d5b36304b3f92f3bd5d42c1a3fa37b5ba5cdf6dfcbcee"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf4475b82be41b07cc5e5ff94810e6a01f276e37c2d55571e3fe175e467a1a1c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce031db0408e487fd2775d745ce30a7cd2923667cf3b69d48d219f1d8f5ddeb6"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ff4e7cdfdb1ab5698e675ca622e72d58a6fa2a8aa58195de0c0061288e6e3ea"}, 
+ {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3710a9751938947e6327ea9f3ea6332a09bf0ba0c09cae9cb1f250bd1f1549bc"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82357d85de703176b5587dbe6ade8ff67f9f69a41c0733cf2425378b49954de5"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:47334db71978b23ebcf3c0f9f5ee98b8d65992b65c9c4f2d34c2eaf5bcaf0594"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8ce7fd6767a1cc5a92a639b391891bf1c268b03ec7e021c7d6d902285259685c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f1a2f519ae173b5b6a2c9d5fa3116ce16e48b3462c8b96dfdded11055e3d6365"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:63bc5c4ae26e4bc6be6469943b8253c0fd4e4186c43ad46e713ea61a0ba49129"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bcb4f8ea87d03bc51ad04add8ceaf9b0f085ac045ab4d74e73bbc2dc033f0236"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-win32.whl", hash = "sha256:9ae4ef0b3f6b41bad6366fb0ea4fc1d7ed051528e113a60fa2a65a9abb5b1d99"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cee4373f4d3ad28f1ab6290684d8e2ebdb9e7a1b74fdc39e4c211995f77bec27"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0713f3adb9d03d49d365b70b84775d0a0d18e4ab08d12bc46baa6132ba78aaf6"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:de7376c29d95d6719048c194a9cf1a1b0393fbe8488a22008610b0361d834ecf"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a51b48f42d9358460b78725283f04bddaf44a9358197b889657deba38f329db"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b295729485b06c1a0683af02a9e42d2caa9db04a373dc38a6a58cdd1e8abddf1"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee803480535c44e7f5ad00788526da7d85525cfefaf8acf8ab9a310000be4b03"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d59d125ffbd6d552765510e3f31ed75ebac2c7470c7274195b9161a32350284"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cda06946eac330cbe6598f77bb54e690b4ca93f593dee1568ad22b04f347c15"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07afec21bbbbf8a5cc3651aa96b980afe2526e7f048fdfb7f1014d84acc8b6d8"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6b40e8d38afe634559e398cc32b1472f376a4099c75fe6299ae607e404c033b2"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b8dcd239c743aa2f9c22ce674a145e0a25cb1566c495928440a181ca1ccf6719"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:84450ba661fb96e9fd67629b93d2941c871ca86fc38d835d19d4225ff946a631"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:44aeb140295a2f0659e113b31cfe92c9061622cadbc9e2a2f7b8ef6b1e29ef4b"}, + {file = 
"charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1db4e7fefefd0f548d73e2e2e041f9df5c59e178b4c72fbac4cc6f535cfb1565"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-win32.whl", hash = "sha256:5726cf76c982532c1863fb64d8c6dd0e4c90b6ece9feb06c9f202417a31f7dd7"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:b197e7094f232959f8f20541ead1d9862ac5ebea1d58e9849c1bf979255dfac9"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dd4eda173a9fcccb5f2e2bd2a9f423d180194b1bf17cf59e3269899235b2a114"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9e3c4c9e1ed40ea53acf11e2a386383c3304212c965773704e4603d589343ed"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:92a7e36b000bf022ef3dbb9c46bfe2d52c047d5e3f3343f43204263c5addc250"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54b6a92d009cbe2fb11054ba694bc9e284dad30a26757b1e372a1fdddaf21920"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ffd9493de4c922f2a38c2bf62b831dcec90ac673ed1ca182fe11b4d8e9f2a64"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35c404d74c2926d0287fbd63ed5d27eb911eb9e4a3bb2c6d294f3cfd4a9e0c23"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4796efc4faf6b53a18e3d46343535caed491776a22af773f366534056c4e1fbc"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7fdd52961feb4c96507aa649550ec2a0d527c086d284749b2f582f2d40a2e0d"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:92db3c28b5b2a273346bebb24857fda45601aef6ae1c011c0a997106581e8a88"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ab973df98fc99ab39080bfb0eb3a925181454d7c3ac8a1e695fddfae696d9e90"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4b67fdab07fdd3c10bb21edab3cbfe8cf5696f453afce75d815d9d7223fbe88b"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:aa41e526a5d4a9dfcfbab0716c7e8a1b215abd3f3df5a45cf18a12721d31cb5d"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ffc519621dce0c767e96b9c53f09c5d215578e10b02c285809f76509a3931482"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-win32.whl", hash = "sha256:f19c1585933c82098c2a520f8ec1227f20e339e33aca8fa6f956f6691b784e67"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:707b82d19e65c9bd28b81dde95249b07bf9f5b90ebe1ef17d9b57473f8a64b7b"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:dbe03226baf438ac4fda9e2d0715022fd579cb641c4cf639fa40d53b2fe6f3e2"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd9a8bd8900e65504a305bf8ae6fa9fbc66de94178c420791d0293702fce2df7"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8831399554b92b72af5932cdbbd4ddc55c55f631bb13ff8fe4e6536a06c5c51"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:a14969b8691f7998e74663b77b4c36c0337cb1df552da83d5c9004a93afdb574"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcaf7c1524c0542ee2fc82cc8ec337f7a9f7edee2532421ab200d2b920fc97cf"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425c5f215d0eecee9a56cdb703203dda90423247421bf0d67125add85d0c4455"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:d5b054862739d276e09928de37c79ddeec42a6e1bfc55863be96a36ba22926f6"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:f3e73a4255342d4eb26ef6df01e3962e73aa29baa3124a8e824c5d3364a65748"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:2f6c34da58ea9c1a9515621f4d9ac379871a8f21168ba1b5e09d74250de5ad62"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:f09cb5a7bbe1ecae6e87901a2eb23e0256bb524a79ccc53eb0b7629fbe7677c4"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:0099d79bdfcf5c1f0c2c72f91516702ebf8b0b8ddd8905f97a8aecf49712c621"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-win32.whl", hash = "sha256:9c98230f5042f4945f957d006edccc2af1e03ed5e37ce7c373f00a5a4daa6149"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:62f60aebecfc7f4b82e3f639a7d1433a20ec32824db2199a11ad4f5e146ef5ee"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:af73657b7a68211996527dbfeffbb0864e043d270580c5aef06dc4b659a4b578"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cab5d0b79d987c67f3b9e9c53f54a61360422a5a0bc075f43cab5621d530c3b6"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9289fd5dddcf57bab41d044f1756550f9e7cf0c8e373b8cdf0ce8773dc4bd417"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b493a043635eb376e50eedf7818f2f322eabbaa974e948bd8bdd29eb7ef2a51"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fa2566ca27d67c86569e8c85297aaf413ffab85a8960500f12ea34ff98e4c41"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8e538f46104c815be19c975572d74afb53f29650ea2025bbfaef359d2de2f7f"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fd30dc99682dc2c603c2b315bded2799019cea829f8bf57dc6b61efde6611c8"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2006769bd1640bdf4d5641c69a3d63b71b81445473cac5ded39740a226fa88ab"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:dc15e99b2d8a656f8e666854404f1ba54765871104e50c8e9813af8a7db07f12"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ab2e5bef076f5a235c3774b4f4028a680432cded7cad37bba0fd90d64b187d19"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:4ec9dd88a5b71abfc74e9df5ebe7921c35cbb3b641181a531ca65cdb5e8e4dea"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:43193c5cda5d612f247172016c4bb71251c784d7a4d9314677186a838ad34858"}, + 
{file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:aa693779a8b50cd97570e5a0f343538a8dbd3e496fa5dcb87e29406ad0299654"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-win32.whl", hash = "sha256:7706f5850360ac01d80c89bcef1640683cc12ed87f42579dab6c5d3ed6888613"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:c3e446d253bd88f6377260d07c895816ebf33ffffd56c1c792b13bff9c3e1ade"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:980b4f289d1d90ca5efcf07958d3eb38ed9c0b7676bf2831a54d4f66f9c27dfa"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f28f891ccd15c514a0981f3b9db9aa23d62fe1a99997512b0491d2ed323d229a"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8aacce6e2e1edcb6ac625fb0f8c3a9570ccc7bfba1f63419b3769ccf6a00ed0"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7af3717683bea4c87acd8c0d3d5b44d56120b26fd3f8a692bdd2d5260c620a"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ff2ed8194587faf56555927b3aa10e6fb69d931e33953943bc4f837dfee2242"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e91f541a85298cf35433bf66f3fab2a4a2cff05c127eeca4af174f6d497f0d4b"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:309a7de0a0ff3040acaebb35ec45d18db4b28232f21998851cfa709eeff49d62"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:285e96d9d53422efc0d7a17c60e59f37fbf3dfa942073f666db4ac71e8d726d0"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5d447056e2ca60382d460a604b6302d8db69476fd2015c81e7c35417cfabe4cd"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:20587d20f557fe189b7947d8e7ec5afa110ccf72a3128d61a2a387c3313f46be"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:130272c698667a982a5d0e626851ceff662565379baf0ff2cc58067b81d4f11d"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ab22fbd9765e6954bc0bcff24c25ff71dcbfdb185fcdaca49e81bac68fe724d3"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7782afc9b6b42200f7362858f9e73b1f8316afb276d316336c0ec3bd73312742"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-win32.whl", hash = "sha256:2de62e8801ddfff069cd5c504ce3bc9672b23266597d4e4f50eda28846c322f2"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:95c3c157765b031331dd4db3c775e58deaee050a3042fcad72cbc4189d7c8dca"}, + {file = "charset_normalizer-3.4.0-py3-none-any.whl", hash = "sha256:fe9f97feb71aa9896b81973a7bbada8c49501dc73e58a10fcef6663af95e5079"}, + {file = "charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e"}, +] + +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = 
"sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "cloudpickle" +version = "3.1.0" +description = "Pickler class to extend the standard pickle.Pickler functionality" +optional = false +python-versions = ">=3.8" +files = [ + {file = "cloudpickle-3.1.0-py3-none-any.whl", hash = "sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e"}, + {file = "cloudpickle-3.1.0.tar.gz", hash = "sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b"}, +] + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "contourpy" +version = "1.3.1" +description = "Python library for calculating contours of 2D quadrilateral grids" +optional = false +python-versions = ">=3.10" +files = [ + {file = "contourpy-1.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a045f341a77b77e1c5de31e74e966537bba9f3c4099b35bf4c2e3939dd54cdab"}, + {file = "contourpy-1.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:500360b77259914f7805af7462e41f9cb7ca92ad38e9f94d6c8641b089338124"}, + {file = "contourpy-1.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2f926efda994cdf3c8d3fdb40b9962f86edbc4457e739277b961eced3d0b4c1"}, + {file = "contourpy-1.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:adce39d67c0edf383647a3a007de0a45fd1b08dedaa5318404f1a73059c2512b"}, + {file = "contourpy-1.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abbb49fb7dac584e5abc6636b7b2a7227111c4f771005853e7d25176daaf8453"}, + {file = "contourpy-1.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0cffcbede75c059f535725c1680dfb17b6ba8753f0c74b14e6a9c68c29d7ea3"}, + {file = "contourpy-1.3.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ab29962927945d89d9b293eabd0d59aea28d887d4f3be6c22deaefbb938a7277"}, + {file = "contourpy-1.3.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:974d8145f8ca354498005b5b981165b74a195abfae9a8129df3e56771961d595"}, + {file = "contourpy-1.3.1-cp310-cp310-win32.whl", hash = "sha256:ac4578ac281983f63b400f7fe6c101bedc10651650eef012be1ccffcbacf3697"}, + {file = "contourpy-1.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:174e758c66bbc1c8576992cec9599ce8b6672b741b5d336b5c74e35ac382b18e"}, + {file = "contourpy-1.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8b974d8db2c5610fb4e76307e265de0edb655ae8169e8b21f41807ccbeec4b"}, + {file = "contourpy-1.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20914c8c973f41456337652a6eeca26d2148aa96dd7ac323b74516988bea89fc"}, + {file = "contourpy-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19d40d37c1c3a4961b4619dd9d77b12124a453cc3d02bb31a07d58ef684d3d86"}, + {file = "contourpy-1.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:113231fe3825ebf6f15eaa8bc1f5b0ddc19d42b733345eae0934cb291beb88b6"}, + {file = 
"contourpy-1.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dbbc03a40f916a8420e420d63e96a1258d3d1b58cbdfd8d1f07b49fcbd38e85"}, + {file = "contourpy-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a04ecd68acbd77fa2d39723ceca4c3197cb2969633836ced1bea14e219d077c"}, + {file = "contourpy-1.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c414fc1ed8ee1dbd5da626cf3710c6013d3d27456651d156711fa24f24bd1291"}, + {file = "contourpy-1.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:31c1b55c1f34f80557d3830d3dd93ba722ce7e33a0b472cba0ec3b6535684d8f"}, + {file = "contourpy-1.3.1-cp311-cp311-win32.whl", hash = "sha256:f611e628ef06670df83fce17805c344710ca5cde01edfdc72751311da8585375"}, + {file = "contourpy-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:b2bdca22a27e35f16794cf585832e542123296b4687f9fd96822db6bae17bfc9"}, + {file = "contourpy-1.3.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:0ffa84be8e0bd33410b17189f7164c3589c229ce5db85798076a3fa136d0e509"}, + {file = "contourpy-1.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:805617228ba7e2cbbfb6c503858e626ab528ac2a32a04a2fe88ffaf6b02c32bc"}, + {file = "contourpy-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ade08d343436a94e633db932e7e8407fe7de8083967962b46bdfc1b0ced39454"}, + {file = "contourpy-1.3.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:47734d7073fb4590b4a40122b35917cd77be5722d80683b249dac1de266aac80"}, + {file = "contourpy-1.3.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2ba94a401342fc0f8b948e57d977557fbf4d515f03c67682dd5c6191cb2d16ec"}, + {file = "contourpy-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efa874e87e4a647fd2e4f514d5e91c7d493697127beb95e77d2f7561f6905bd9"}, + {file = "contourpy-1.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1bf98051f1045b15c87868dbaea84f92408337d4f81d0e449ee41920ea121d3b"}, + {file = "contourpy-1.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:61332c87493b00091423e747ea78200659dc09bdf7fd69edd5e98cef5d3e9a8d"}, + {file = "contourpy-1.3.1-cp312-cp312-win32.whl", hash = "sha256:e914a8cb05ce5c809dd0fe350cfbb4e881bde5e2a38dc04e3afe1b3e58bd158e"}, + {file = "contourpy-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:08d9d449a61cf53033612cb368f3a1b26cd7835d9b8cd326647efe43bca7568d"}, + {file = "contourpy-1.3.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a761d9ccfc5e2ecd1bf05534eda382aa14c3e4f9205ba5b1684ecfe400716ef2"}, + {file = "contourpy-1.3.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:523a8ee12edfa36f6d2a49407f705a6ef4c5098de4f498619787e272de93f2d5"}, + {file = "contourpy-1.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece6df05e2c41bd46776fbc712e0996f7c94e0d0543af1656956d150c4ca7c81"}, + {file = "contourpy-1.3.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:573abb30e0e05bf31ed067d2f82500ecfdaec15627a59d63ea2d95714790f5c2"}, + {file = "contourpy-1.3.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9fa36448e6a3a1a9a2ba23c02012c43ed88905ec80163f2ffe2421c7192a5d7"}, + {file = "contourpy-1.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ea9924d28fc5586bf0b42d15f590b10c224117e74409dd7a0be3b62b74a501c"}, + {file = "contourpy-1.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:5b75aa69cb4d6f137b36f7eb2ace9280cfb60c55dc5f61c731fdf6f037f958a3"}, + {file = "contourpy-1.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:041b640d4ec01922083645a94bb3b2e777e6b626788f4095cf21abbe266413c1"}, + {file = "contourpy-1.3.1-cp313-cp313-win32.whl", hash = "sha256:36987a15e8ace5f58d4d5da9dca82d498c2bbb28dff6e5d04fbfcc35a9cb3a82"}, + {file = "contourpy-1.3.1-cp313-cp313-win_amd64.whl", hash = "sha256:a7895f46d47671fa7ceec40f31fae721da51ad34bdca0bee83e38870b1f47ffd"}, + {file = "contourpy-1.3.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:9ddeb796389dadcd884c7eb07bd14ef12408aaae358f0e2ae24114d797eede30"}, + {file = "contourpy-1.3.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:19c1555a6801c2f084c7ddc1c6e11f02eb6a6016ca1318dd5452ba3f613a1751"}, + {file = "contourpy-1.3.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:841ad858cff65c2c04bf93875e384ccb82b654574a6d7f30453a04f04af71342"}, + {file = "contourpy-1.3.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4318af1c925fb9a4fb190559ef3eec206845f63e80fb603d47f2d6d67683901c"}, + {file = "contourpy-1.3.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:14c102b0eab282427b662cb590f2e9340a9d91a1c297f48729431f2dcd16e14f"}, + {file = "contourpy-1.3.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05e806338bfeaa006acbdeba0ad681a10be63b26e1b17317bfac3c5d98f36cda"}, + {file = "contourpy-1.3.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4d76d5993a34ef3df5181ba3c92fabb93f1eaa5729504fb03423fcd9f3177242"}, + {file = "contourpy-1.3.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:89785bb2a1980c1bd87f0cb1517a71cde374776a5f150936b82580ae6ead44a1"}, + {file = "contourpy-1.3.1-cp313-cp313t-win32.whl", hash = "sha256:8eb96e79b9f3dcadbad2a3891672f81cdcab7f95b27f28f1c67d75f045b6b4f1"}, + {file = "contourpy-1.3.1-cp313-cp313t-win_amd64.whl", hash = "sha256:287ccc248c9e0d0566934e7d606201abd74761b5703d804ff3df8935f523d546"}, + {file = "contourpy-1.3.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b457d6430833cee8e4b8e9b6f07aa1c161e5e0d52e118dc102c8f9bd7dd060d6"}, + {file = "contourpy-1.3.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb76c1a154b83991a3cbbf0dfeb26ec2833ad56f95540b442c73950af2013750"}, + {file = "contourpy-1.3.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:44a29502ca9c7b5ba389e620d44f2fbe792b1fb5734e8b931ad307071ec58c53"}, + {file = "contourpy-1.3.1.tar.gz", hash = "sha256:dfd97abd83335045a913e3bcc4a09c0ceadbe66580cf573fe961f4a825efa699"}, +] + +[package.dependencies] +numpy = ">=1.23" + +[package.extras] +bokeh = ["bokeh", "selenium"] +docs = ["furo", "sphinx (>=7.2)", "sphinx-copybutton"] +mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.11.1)", "types-Pillow"] +test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] +test-no-images = ["pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "wurlitzer"] + +[[package]] +name = "cycler" +version = "0.12.1" +description = "Composable style cycles" +optional = false +python-versions = ">=3.8" +files = [ + {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"}, + {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, +] + +[package.extras] +docs = ["ipython", "matplotlib", "numpydoc", "sphinx"] +tests = ["pytest", 
"pytest-cov", "pytest-xdist"] + +[[package]] +name = "distlib" +version = "0.3.9" +description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87"}, + {file = "distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403"}, +] + +[[package]] +name = "docker" +version = "7.1.0" +description = "A Python library for the Docker Engine API." +optional = false +python-versions = ">=3.8" +files = [ + {file = "docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0"}, + {file = "docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c"}, +] + +[package.dependencies] +pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""} +requests = ">=2.26.0" +urllib3 = ">=1.26.0" + +[package.extras] +dev = ["coverage (==7.2.7)", "pytest (==7.4.2)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.1.0)", "ruff (==0.1.8)"] +docs = ["myst-parser (==0.18.0)", "sphinx (==5.1.1)"] +ssh = ["paramiko (>=2.4.3)"] +websockets = ["websocket-client (>=1.3.0)"] + +[[package]] +name = "exceptiongroup" +version = "1.2.2" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "fedpydeseq2-datasets" +version = "0.1.0" +description = "This package contains utilities to process TCGA data for fedpydeseq2." +optional = false +python-versions = "<3.13,>=3.10" +files = [ + {file = "fedpydeseq2_datasets-0.1.0-py3-none-any.whl", hash = "sha256:33e29de78ce2680653a3d22fbf4a15e317ae20e5a4d32a6b5de8df197e5015dd"}, + {file = "fedpydeseq2_datasets-0.1.0.tar.gz", hash = "sha256:d9af5fd15bf1f87a9fdcf0ccd0178967b96f9c9adacd0eb0abcc9c9add872c03"}, +] + +[package.dependencies] +anndata = "0.10.8" +gitpython = "3.1.43" +loguru = "0.7.2" +numpy = "1.26.4" +pandas = "2.2.2" +pyarrow = "15.0.2" +pydeseq2 = "0.4.9" +pyyaml = ">=5.1" + +[[package]] +name = "filelock" +version = "3.16.1" +description = "A platform independent file lock." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, + {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] +typing = ["typing-extensions (>=4.12.2)"] + +[[package]] +name = "fonttools" +version = "4.55.1" +description = "Tools to manipulate font files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fonttools-4.55.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:c17a6f9814f83772cd6d9c9009928e1afa4ab66210a31ced721556651075a9a0"}, + {file = "fonttools-4.55.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c4d14eecc814826a01db87a40af3407c892ba49996bc6e49961e386cd78b537c"}, + {file = "fonttools-4.55.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8589f9a15dc005592b94ecdc45b4dfae9bbe9e73542e89af5a5e776e745db83b"}, + {file = "fonttools-4.55.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfee95bd9395bcd9e6c78955387554335109b6a613db71ef006020b42f761c58"}, + {file = "fonttools-4.55.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:34fa2ecc0bf1923d1a51bf2216a006de2c3c0db02c6aa1470ea50b62b8619bd5"}, + {file = "fonttools-4.55.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9c1c48483148bfb1b9ad951133ceea957faa004f6cb475b67e7bc75d482b48f8"}, + {file = "fonttools-4.55.1-cp310-cp310-win32.whl", hash = "sha256:3e2fc388ca7d023b3c45badd71016fd4185f93e51a22cfe4bd65378af7fba759"}, + {file = "fonttools-4.55.1-cp310-cp310-win_amd64.whl", hash = "sha256:c4c36c71f69d2b3ee30394b0986e5f8b2c461e7eff48dde49b08a90ded9fcdbd"}, + {file = "fonttools-4.55.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5daab3a55d460577f45bb8f5a8eca01fa6cde43ef2ab943b527991f54b735c41"}, + {file = "fonttools-4.55.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:acf1e80cf96c2fbc79e46f669d8713a9a79faaebcc68e31a9fbe600cf8027992"}, + {file = "fonttools-4.55.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e88a0329f7f88a210f09f79c088fb64f8032fc3ab65e2390a40b7d3a11773026"}, + {file = "fonttools-4.55.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03105b42259a8a94b2f0cbf1bee45f7a8a34e7b26c946a8fb89b4967e44091a8"}, + {file = "fonttools-4.55.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9af3577e821649879ab5774ad0e060af34816af556c77c6d3820345d12bf415e"}, + {file = "fonttools-4.55.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:34bd5de3d0ad085359b79a96575cd6bd1bc2976320ef24a2aa152ead36dbf656"}, + {file = "fonttools-4.55.1-cp311-cp311-win32.whl", hash = "sha256:5da92c4b637f0155a41f345fa81143c8e17425260fcb21521cb2ad4d2cea2a95"}, + {file = "fonttools-4.55.1-cp311-cp311-win_amd64.whl", hash = "sha256:f70234253d15f844e6da1178f019a931f03181463ce0c7b19648b8c370527b07"}, + {file = "fonttools-4.55.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9c372e527d58ba64b695f15f8014e97bc8826cf64d3380fc89b4196edd3c0fa8"}, + {file = "fonttools-4.55.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:845a967d3bef3245ba81fb5582dc731f6c2c8417fa211f1068c56893504bc000"}, + {file = "fonttools-4.55.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f03be82bcd4ba4418adf10e6165743f824bb09d6594c2743d7f93ea50968805b"}, + {file = "fonttools-4.55.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c42e935cf146f826f556d977660dac88f2fa3fb2efa27d5636c0b89a60c16edf"}, + {file = "fonttools-4.55.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:96328bf91e05621d8e40d9f854af7a262cb0e8313e9b38e7f3a7f3c4c0caaa8b"}, + {file = "fonttools-4.55.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:291acec4d774e8cd2d8472d88c04643a77a3324a15247951bd6cfc969799b69e"}, + {file = "fonttools-4.55.1-cp312-cp312-win32.whl", hash = "sha256:6d768d6632809aec1c3fa8f195b173386d85602334701a6894a601a4d3c80368"}, + {file = "fonttools-4.55.1-cp312-cp312-win_amd64.whl", hash = "sha256:2a3850afdb0be1f79a1e95340a2059226511675c5b68098d4e49bfbeb48a8aab"}, + {file = "fonttools-4.55.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:0c88d427eaf8bd8497b9051f56e0f5f9fb96a311aa7c72cda35e03e18d59cd16"}, + {file = "fonttools-4.55.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f062c95a725a79fd908fe8407b6ad63e230e1c7d6dece2d5d6ecaf843d6927f6"}, + {file = "fonttools-4.55.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f298c5324c45cad073475146bf560f4110ce2dc2488ff12231a343ec489f77bc"}, + {file = "fonttools-4.55.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f06dbb71344ffd85a6cb7e27970a178952f0bdd8d319ed938e64ba4bcc41700"}, + {file = "fonttools-4.55.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4c46b3525166976f5855b1f039b02433dc51eb635fb54d6a111e0c5d6e6cdc4c"}, + {file = "fonttools-4.55.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:af46f52a21e086a2f89b87bd941c9f0f91e5f769e1a5eb3b37c912228814d3e5"}, + {file = "fonttools-4.55.1-cp313-cp313-win32.whl", hash = "sha256:cd7f36335c5725a3fd724cc667c10c3f5254e779bdc5bffefebb33cf5a75ecb1"}, + {file = "fonttools-4.55.1-cp313-cp313-win_amd64.whl", hash = "sha256:5d6394897710ccac7f74df48492d7f02b9586ff0588c66a2c218844e90534b22"}, + {file = "fonttools-4.55.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:52c4f4b383c56e1a4fe8dab1b63c2269ba9eab0695d2d8e033fa037e61e6f1ef"}, + {file = "fonttools-4.55.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d83892dafdbd62b56545c77b6bd4fa49eef6ec1d6b95e042ee2c930503d1831e"}, + {file = "fonttools-4.55.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:604d5bf16f811fcaaaec2dde139f7ce958462487565edcd54b6fadacb2942083"}, + {file = "fonttools-4.55.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3324b92feb5fd084923a8e89a8248afd5b9f9d81ab9517d7b07cc84403bd448"}, + {file = "fonttools-4.55.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:30f8b1ca9b919c04850678d026fc330c19acaa9e3b282fcacc09a5eb3c8d20c3"}, + {file = "fonttools-4.55.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:1835c98df2cf28c86a66d234895c87df7b9325fd079a8019c5053a389ff55d23"}, + {file = "fonttools-4.55.1-cp38-cp38-win32.whl", hash = "sha256:9f202703720a7cc0049f2ed1a2047925e264384eb5cc4d34f80200d7b17f1b6a"}, + {file = "fonttools-4.55.1-cp38-cp38-win_amd64.whl", hash = "sha256:2efff20aed0338d37c2ff58766bd67f4b9607ded61cf3d6baf1b3e25ea74e119"}, + {file = 
"fonttools-4.55.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3032d9bf010c395e6eca2851666cafb1f4ecde85d420188555e928ad0144326e"}, + {file = "fonttools-4.55.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0794055588c30ffe25426048e8a7c0a5271942727cd61fc939391e37f4d580d5"}, + {file = "fonttools-4.55.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13ba980e3ffd3206b8c63a365f90dc10eeec27da946d5ee5373c3a325a46d77c"}, + {file = "fonttools-4.55.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d7063babd7434a17a5e355e87de9b2306c85a5c19c7da0794be15c58aab0c39"}, + {file = "fonttools-4.55.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ed84c15144015a58ef550dd6312884c9fb31a2dbc31a6467bcdafd63be7db476"}, + {file = "fonttools-4.55.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:e89419d88b0bbfdb55209e03a17afa2d20db3c2fa0d785543c9d0875668195d5"}, + {file = "fonttools-4.55.1-cp39-cp39-win32.whl", hash = "sha256:6eb781e401b93cda99356bc043ababead2a5096550984d8a4ecf3d5c9f859dc2"}, + {file = "fonttools-4.55.1-cp39-cp39-win_amd64.whl", hash = "sha256:db1031acf04523c5a51c3e1ae19c21a1c32bc5f820a477dd4659a02f9cb82002"}, + {file = "fonttools-4.55.1-py3-none-any.whl", hash = "sha256:4bcfb11f90f48b48c366dd638d773a52fca0d1b9e056dc01df766bf5835baa08"}, + {file = "fonttools-4.55.1.tar.gz", hash = "sha256:85bb2e985718b0df96afc659abfe194c171726054314b019dbbfed31581673c7"}, +] + +[package.extras] +all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "pycairo", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"] +graphite = ["lz4 (>=1.7.4.2)"] +interpolatable = ["munkres", "pycairo", "scipy"] +lxml = ["lxml (>=4.0)"] +pathops = ["skia-pathops (>=0.5.0)"] +plot = ["matplotlib"] +repacker = ["uharfbuzz (>=0.23.0)"] +symfont = ["sympy"] +type1 = ["xattr"] +ufo = ["fs (>=2.2.0,<3)"] +unicode = ["unicodedata2 (>=15.1.0)"] +woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] + +[[package]] +name = "gitdb" +version = "4.0.11" +description = "Git Object Database" +optional = false +python-versions = ">=3.7" +files = [ + {file = "gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4"}, + {file = "gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b"}, +] + +[package.dependencies] +smmap = ">=3.0.1,<6" + +[[package]] +name = "gitpython" +version = "3.1.43" +description = "GitPython is a Python library used to interact with Git repositories" +optional = false +python-versions = ">=3.7" +files = [ + {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"}, + {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"}, +] + +[package.dependencies] +gitdb = ">=4.0.1,<5" + +[package.extras] +doc = ["sphinx (==4.3.2)", "sphinx-autodoc-typehints", "sphinx-rtd-theme", "sphinxcontrib-applehelp (>=1.0.2,<=1.0.4)", "sphinxcontrib-devhelp (==1.0.2)", "sphinxcontrib-htmlhelp (>=2.0.0,<=2.0.1)", "sphinxcontrib-qthelp (==1.0.3)", "sphinxcontrib-serializinghtml (==1.1.5)"] +test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"] + +[[package]] 
+name = "h5py" +version = "3.12.1" +description = "Read and write HDF5 files from Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "h5py-3.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2f0f1a382cbf494679c07b4371f90c70391dedb027d517ac94fa2c05299dacda"}, + {file = "h5py-3.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cb65f619dfbdd15e662423e8d257780f9a66677eae5b4b3fc9dca70b5fd2d2a3"}, + {file = "h5py-3.12.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b15d8dbd912c97541312c0e07438864d27dbca857c5ad634de68110c6beb1c2"}, + {file = "h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59685fe40d8c1fbbee088c88cd4da415a2f8bee5c270337dc5a1c4aa634e3307"}, + {file = "h5py-3.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:577d618d6b6dea3da07d13cc903ef9634cde5596b13e832476dd861aaf651f3e"}, + {file = "h5py-3.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ccd9006d92232727d23f784795191bfd02294a4f2ba68708825cb1da39511a93"}, + {file = "h5py-3.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ad8a76557880aed5234cfe7279805f4ab5ce16b17954606cca90d578d3e713ef"}, + {file = "h5py-3.12.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1473348139b885393125126258ae2d70753ef7e9cec8e7848434f385ae72069e"}, + {file = "h5py-3.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:018a4597f35092ae3fb28ee851fdc756d2b88c96336b8480e124ce1ac6fb9166"}, + {file = "h5py-3.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:3fdf95092d60e8130ba6ae0ef7a9bd4ade8edbe3569c13ebbaf39baefffc5ba4"}, + {file = "h5py-3.12.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:06a903a4e4e9e3ebbc8b548959c3c2552ca2d70dac14fcfa650d9261c66939ed"}, + {file = "h5py-3.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7b3b8f3b48717e46c6a790e3128d39c61ab595ae0a7237f06dfad6a3b51d5351"}, + {file = "h5py-3.12.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:050a4f2c9126054515169c49cb900949814987f0c7ae74c341b0c9f9b5056834"}, + {file = "h5py-3.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c4b41d1019322a5afc5082864dfd6359f8935ecd37c11ac0029be78c5d112c9"}, + {file = "h5py-3.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4d51919110a030913201422fb07987db4338eba5ec8c5a15d6fab8e03d443fc"}, + {file = "h5py-3.12.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:513171e90ed92236fc2ca363ce7a2fc6f2827375efcbb0cc7fbdd7fe11fecafc"}, + {file = "h5py-3.12.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:59400f88343b79655a242068a9c900001a34b63e3afb040bd7cdf717e440f653"}, + {file = "h5py-3.12.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3e465aee0ec353949f0f46bf6c6f9790a2006af896cee7c178a8c3e5090aa32"}, + {file = "h5py-3.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba51c0c5e029bb5420a343586ff79d56e7455d496d18a30309616fdbeed1068f"}, + {file = "h5py-3.12.1-cp313-cp313-win_amd64.whl", hash = "sha256:52ab036c6c97055b85b2a242cb540ff9590bacfda0c03dd0cf0661b311f522f8"}, + {file = "h5py-3.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d2b8dd64f127d8b324f5d2cd1c0fd6f68af69084e9e47d27efeb9e28e685af3e"}, + {file = "h5py-3.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4532c7e97fbef3d029735db8b6f5bf01222d9ece41e309b20d63cfaae2fb5c4d"}, + {file = "h5py-3.12.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:6fdf6d7936fa824acfa27305fe2d9f39968e539d831c5bae0e0d83ed521ad1ac"}, + {file = "h5py-3.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84342bffd1f82d4f036433e7039e241a243531a1d3acd7341b35ae58cdab05bf"}, + {file = "h5py-3.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:62be1fc0ef195891949b2c627ec06bc8e837ff62d5b911b6e42e38e0f20a897d"}, + {file = "h5py-3.12.1.tar.gz", hash = "sha256:326d70b53d31baa61f00b8aa5f95c2fcb9621a3ee8365d770c551a13dbbcbfdf"}, +] + +[package.dependencies] +numpy = ">=1.19.3" + +[[package]] +name = "identify" +version = "2.6.3" +description = "File identification library for Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "identify-2.6.3-py2.py3-none-any.whl", hash = "sha256:9edba65473324c2ea9684b1f944fe3191db3345e50b6d04571d10ed164f8d7bd"}, + {file = "identify-2.6.3.tar.gz", hash = "sha256:62f5dae9b5fef52c84cc188514e9ea4f3f636b1d8799ab5ebc475471f9e47a02"}, +] + +[package.extras] +license = ["ukkonen"] + +[[package]] +name = "idna" +version = "3.10" +description = "Internationalized Domain Names in Applications (IDNA)" +optional = false +python-versions = ">=3.6" +files = [ + {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, + {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, +] + +[package.extras] +all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] + +[[package]] +name = "importlib-metadata" +version = "8.5.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-8.5.0-py3-none-any.whl", hash = "sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b"}, + {file = "importlib_metadata-8.5.0.tar.gz", hash = "sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7"}, +] + +[package.dependencies] +zipp = ">=3.20" + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +perf = ["ipython"] +test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] +type = ["pytest-mypy"] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "joblib" +version = "1.4.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, + {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, +] + +[[package]] +name = "kiwisolver" +version = "1.4.7" +description = "A fast implementation of the Cassowary constraint solver" +optional = false +python-versions = ">=3.8" +files = [ + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_10_9_universal2.whl", 
hash = "sha256:8a9c83f75223d5e48b0bc9cb1bf2776cf01563e00ade8775ffe13b0b6e1af3a6"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:58370b1ffbd35407444d57057b57da5d6549d2d854fa30249771775c63b5fe17"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aa0abdf853e09aff551db11fce173e2177d00786c688203f52c87ad7fcd91ef9"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8d53103597a252fb3ab8b5845af04c7a26d5e7ea8122303dd7a021176a87e8b9"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:88f17c5ffa8e9462fb79f62746428dd57b46eb931698e42e990ad63103f35e6c"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88a9ca9c710d598fd75ee5de59d5bda2684d9db36a9f50b6125eaea3969c2599"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f4d742cb7af1c28303a51b7a27aaee540e71bb8e24f68c736f6f2ffc82f2bf05"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e28c7fea2196bf4c2f8d46a0415c77a1c480cc0724722f23d7410ffe9842c407"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e968b84db54f9d42046cf154e02911e39c0435c9801681e3fc9ce8a3c4130278"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0c18ec74c0472de033e1bebb2911c3c310eef5649133dd0bedf2a169a1b269e5"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8f0ea6da6d393d8b2e187e6a5e3fb81f5862010a40c3945e2c6d12ae45cfb2ad"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:f106407dda69ae456dd1227966bf445b157ccc80ba0dff3802bb63f30b74e895"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:84ec80df401cfee1457063732d90022f93951944b5b58975d34ab56bb150dfb3"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win32.whl", hash = "sha256:71bb308552200fb2c195e35ef05de12f0c878c07fc91c270eb3d6e41698c3bcc"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win_amd64.whl", hash = "sha256:44756f9fd339de0fb6ee4f8c1696cfd19b2422e0d70b4cefc1cc7f1f64045a8c"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win_arm64.whl", hash = "sha256:78a42513018c41c2ffd262eb676442315cbfe3c44eed82385c2ed043bc63210a"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d2b0e12a42fb4e72d509fc994713d099cbb15ebf1103545e8a45f14da2dfca54"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2a8781ac3edc42ea4b90bc23e7d37b665d89423818e26eb6df90698aa2287c95"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:46707a10836894b559e04b0fd143e343945c97fd170d69a2d26d640b4e297935"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef97b8df011141c9b0f6caf23b29379f87dd13183c978a30a3c546d2c47314cb"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ab58c12a2cd0fc769089e6d38466c46d7f76aced0a1f54c77652446733d2d02"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:803b8e1459341c1bb56d1c5c010406d5edec8a0713a0945851290a7930679b51"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:f9a9e8a507420fe35992ee9ecb302dab68550dedc0da9e2880dd88071c5fb052"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18077b53dc3bb490e330669a99920c5e6a496889ae8c63b58fbc57c3d7f33a18"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6af936f79086a89b3680a280c47ea90b4df7047b5bdf3aa5c524bbedddb9e545"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:3abc5b19d24af4b77d1598a585b8a719beb8569a71568b66f4ebe1fb0449460b"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:933d4de052939d90afbe6e9d5273ae05fb836cc86c15b686edd4b3560cc0ee36"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:65e720d2ab2b53f1f72fb5da5fb477455905ce2c88aaa671ff0a447c2c80e8e3"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3bf1ed55088f214ba6427484c59553123fdd9b218a42bbc8c6496d6754b1e523"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win32.whl", hash = "sha256:4c00336b9dd5ad96d0a558fd18a8b6f711b7449acce4c157e7343ba92dd0cf3d"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win_amd64.whl", hash = "sha256:929e294c1ac1e9f615c62a4e4313ca1823ba37326c164ec720a803287c4c499b"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win_arm64.whl", hash = "sha256:e33e8fbd440c917106b237ef1a2f1449dfbb9b6f6e1ce17c94cd6a1e0d438376"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5360cc32706dab3931f738d3079652d20982511f7c0ac5711483e6eab08efff2"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:942216596dc64ddb25adb215c3c783215b23626f8d84e8eff8d6d45c3f29f75a"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:48b571ecd8bae15702e4f22d3ff6a0f13e54d3d00cd25216d5e7f658242065ee"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad42ba922c67c5f219097b28fae965e10045ddf145d2928bfac2eb2e17673640"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:612a10bdae23404a72941a0fc8fa2660c6ea1217c4ce0dbcab8a8f6543ea9e7f"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9e838bba3a3bac0fe06d849d29772eb1afb9745a59710762e4ba3f4cb8424483"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:22f499f6157236c19f4bbbd472fa55b063db77a16cd74d49afe28992dff8c258"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693902d433cf585133699972b6d7c42a8b9f8f826ebcaf0132ff55200afc599e"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4e77f2126c3e0b0d055f44513ed349038ac180371ed9b52fe96a32aa071a5107"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:657a05857bda581c3656bfc3b20e353c232e9193eb167766ad2dc58b56504948"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4bfa75a048c056a411f9705856abfc872558e33c055d80af6a380e3658766038"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:34ea1de54beef1c104422d210c47c7d2a4999bdecf42c7b5718fbe59a4cac383"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:90da3b5f694b85231cf93586dad5e90e2d71b9428f9aad96952c99055582f520"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win32.whl", 
hash = "sha256:18e0cca3e008e17fe9b164b55735a325140a5a35faad8de92dd80265cd5eb80b"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_amd64.whl", hash = "sha256:58cb20602b18f86f83a5c87d3ee1c766a79c0d452f8def86d925e6c60fbf7bfb"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_arm64.whl", hash = "sha256:f5a8b53bdc0b3961f8b6125e198617c40aeed638b387913bf1ce78afb1b0be2a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2e6039dcbe79a8e0f044f1c39db1986a1b8071051efba3ee4d74f5b365f5226e"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a1ecf0ac1c518487d9d23b1cd7139a6a65bc460cd101ab01f1be82ecf09794b6"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7ab9ccab2b5bd5702ab0803676a580fffa2aa178c2badc5557a84cc943fcf750"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf8bcc23ceb5a1b624572a1623b9f79d2c3b337c8c455405ef231933a10da379"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dea0bf229319828467d7fca8c7c189780aa9ff679c94539eed7532ebe33ed37c"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c06a4c7cf15ec739ce0e5971b26c93638730090add60e183530d70848ebdd34"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:913983ad2deb14e66d83c28b632fd35ba2b825031f2fa4ca29675e665dfecbe1"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5337ec7809bcd0f424c6b705ecf97941c46279cf5ed92311782c7c9c2026f07f"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4c26ed10c4f6fa6ddb329a5120ba3b6db349ca192ae211e882970bfc9d91420b"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c619b101e6de2222c1fcb0531e1b17bbffbe54294bfba43ea0d411d428618c27"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3ce6b2b0231bda412463e152fc18335ba32faf4e8c23a754ad50ffa70e4091ee"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win32.whl", hash = "sha256:f4c9aee212bc89d4e13f58be11a56cc8036cabad119259d12ace14b34476fd07"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win_amd64.whl", hash = "sha256:8a3ec5aa8e38fc4c8af308917ce12c536f1c88452ce554027e55b22cbbfbff76"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win_arm64.whl", hash = "sha256:76c8094ac20ec259471ac53e774623eb62e6e1f56cd8690c67ce6ce4fcb05650"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5d5abf8f8ec1f4e22882273c423e16cae834c36856cac348cfbfa68e01c40f3a"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:aeb3531b196ef6f11776c21674dba836aeea9d5bd1cf630f869e3d90b16cfade"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b7d755065e4e866a8086c9bdada157133ff466476a2ad7861828e17b6026e22c"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95"}, + {file = 
"kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7bbfcb7165ce3d54a3dfbe731e470f65739c4c1f85bb1018ee912bae139e263b"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d34eb8494bea691a1a450141ebb5385e4b69d38bb8403b5146ad279f4b30fa3"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9242795d174daa40105c1d86aba618e8eab7bf96ba8c3ee614da8302a9f95503"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a0f64a48bb81af7450e641e3fe0b0394d7381e342805479178b3d335d60ca7cf"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8e045731a5416357638d1700927529e2b8ab304811671f665b225f8bf8d8f933"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4322872d5772cae7369f8351da1edf255a604ea7087fe295411397d0cfd9655e"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:e1631290ee9271dffe3062d2634c3ecac02c83890ada077d225e081aca8aab89"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:edcfc407e4eb17e037bca59be0e85a2031a2ac87e4fed26d3e9df88b4165f92d"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4d05d81ecb47d11e7f8932bd8b61b720bf0b41199358f3f5e36d38e28f0532c5"}, + {file = "kiwisolver-1.4.7-cp38-cp38-win32.whl", hash = "sha256:b38ac83d5f04b15e515fd86f312479d950d05ce2368d5413d46c088dda7de90a"}, + {file = "kiwisolver-1.4.7-cp38-cp38-win_amd64.whl", hash = "sha256:d83db7cde68459fc803052a55ace60bea2bae361fc3b7a6d5da07e11954e4b09"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3f9362ecfca44c863569d3d3c033dbe8ba452ff8eed6f6b5806382741a1334bd"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e8df2eb9b2bac43ef8b082e06f750350fbbaf2887534a5be97f6cf07b19d9583"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f32d6edbc638cde7652bd690c3e728b25332acbadd7cad670cc4a02558d9c417"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e2e6c39bd7b9372b0be21456caab138e8e69cc0fc1190a9dfa92bd45a1e6e904"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dda56c24d869b1193fcc763f1284b9126550eaf84b88bbc7256e15028f19188a"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79849239c39b5e1fd906556c474d9b0439ea6792b637511f3fe3a41158d89ca8"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5e3bc157fed2a4c02ec468de4ecd12a6e22818d4f09cde2c31ee3226ffbefab2"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3da53da805b71e41053dc670f9a820d1157aae77b6b944e08024d17bcd51ef88"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8705f17dfeb43139a692298cb6637ee2e59c0194538153e83e9ee0c75c2eddde"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:82a5c2f4b87c26bb1a0ef3d16b5c4753434633b83d365cc0ddf2770c93829e3c"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce8be0466f4c0d585cdb6c1e2ed07232221df101a4c6f28821d2aa754ca2d9e2"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:409afdfe1e2e90e6ee7fc896f3df9a7fec8e793e58bfa0d052c8a82f99c37abb"}, + {file 
= "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5b9c3f4ee0b9a439d2415012bd1b1cc2df59e4d6a9939f4d669241d30b414327"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win32.whl", hash = "sha256:a79ae34384df2b615eefca647a2873842ac3b596418032bef9a7283675962644"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win_amd64.whl", hash = "sha256:cf0438b42121a66a3a667de17e779330fc0f20b0d97d59d2f2121e182b0505e4"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win_arm64.whl", hash = "sha256:764202cc7e70f767dab49e8df52c7455e8de0df5d858fa801a11aa0d882ccf3f"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:94252291e3fe68001b1dd747b4c0b3be12582839b95ad4d1b641924d68fd4643"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:5b7dfa3b546da08a9f622bb6becdb14b3e24aaa30adba66749d38f3cc7ea9706"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd3de6481f4ed8b734da5df134cd5a6a64fe32124fe83dde1e5b5f29fe30b1e6"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a91b5f9f1205845d488c928e8570dcb62b893372f63b8b6e98b863ebd2368ff2"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40fa14dbd66b8b8f470d5fc79c089a66185619d31645f9b0773b88b19f7223c4"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:eb542fe7933aa09d8d8f9d9097ef37532a7df6497819d16efe4359890a2f417a"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:bfa1acfa0c54932d5607e19a2c24646fb4c1ae2694437789129cf099789a3b00"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:eee3ea935c3d227d49b4eb85660ff631556841f6e567f0f7bda972df6c2c9935"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f3160309af4396e0ed04db259c3ccbfdc3621b5559b5453075e5de555e1f3a1b"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a17f6a29cf8935e587cc8a4dbfc8368c55edc645283db0ce9801016f83526c2d"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10849fb2c1ecbfae45a693c070e0320a91b35dd4bcf58172c023b994283a124d"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:ac542bf38a8a4be2dc6b15248d36315ccc65f0743f7b1a76688ffb6b5129a5c2"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8b01aac285f91ca889c800042c35ad3b239e704b150cfd3382adfc9dcc780e39"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:48be928f59a1f5c8207154f935334d374e79f2b5d212826307d072595ad76a2e"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f37cfe618a117e50d8c240555331160d73d0411422b59b5ee217843d7b693608"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:599b5c873c63a1f6ed7eead644a8a380cfbdf5db91dcb6f85707aaab213b1674"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:801fa7802e5cfabe3ab0c81a34c323a319b097dfb5004be950482d882f3d7225"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-win_amd64.whl", hash = 
"sha256:0c6c43471bc764fad4bc99c5c2d6d16a676b1abf844ca7c8702bdae92df01ee0"}, + {file = "kiwisolver-1.4.7.tar.gz", hash = "sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60"}, +] + +[[package]] +name = "loguru" +version = "0.7.2" +description = "Python logging made (stupidly) simple" +optional = false +python-versions = ">=3.5" +files = [ + {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, + {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, +] + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] + +[[package]] +name = "matplotlib" +version = "3.9.3" +description = "Python plotting package" +optional = false +python-versions = ">=3.9" +files = [ + {file = "matplotlib-3.9.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:41b016e3be4e740b66c79a031a0a6e145728dbc248142e751e8dab4f3188ca1d"}, + {file = "matplotlib-3.9.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e0143975fc2a6d7136c97e19c637321288371e8f09cff2564ecd73e865ea0b9"}, + {file = "matplotlib-3.9.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f459c8ee2c086455744723628264e43c884be0c7d7b45d84b8cd981310b4815"}, + {file = "matplotlib-3.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:687df7ceff57b8f070d02b4db66f75566370e7ae182a0782b6d3d21b0d6917dc"}, + {file = "matplotlib-3.9.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:edd14cf733fdc4f6e6fe3f705af97676a7e52859bf0044aa2c84e55be739241c"}, + {file = "matplotlib-3.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:1c40c244221a1adbb1256692b1133c6fb89418df27bf759a31a333e7912a4010"}, + {file = "matplotlib-3.9.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:cf2a60daf6cecff6828bc608df00dbc794380e7234d2411c0ec612811f01969d"}, + {file = "matplotlib-3.9.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:213d6dc25ce686516208d8a3e91120c6a4fdae4a3e06b8505ced5b716b50cc04"}, + {file = "matplotlib-3.9.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c52f48eb75fcc119a4fdb68ba83eb5f71656999420375df7c94cc68e0e14686e"}, + {file = "matplotlib-3.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3c93796b44fa111049b88a24105e947f03c01966b5c0cc782e2ee3887b790a3"}, + {file = "matplotlib-3.9.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cd1077b9a09b16d8c3c7075a8add5ffbfe6a69156a57e290c800ed4d435bef1d"}, + {file = "matplotlib-3.9.3-cp311-cp311-win_amd64.whl", hash = "sha256:c96eeeb8c68b662c7747f91a385688d4b449687d29b691eff7068a4602fe6dc4"}, + {file = "matplotlib-3.9.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:0a361bd5583bf0bcc08841df3c10269617ee2a36b99ac39d455a767da908bbbc"}, + {file = "matplotlib-3.9.3-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:e14485bb1b83eeb3d55b6878f9560240981e7bbc7a8d4e1e8c38b9bd6ec8d2de"}, + {file = "matplotlib-3.9.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a8d279f78844aad213c4935c18f8292a9432d51af2d88bca99072c903948045"}, + {file = "matplotlib-3.9.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6c12514329ac0d03128cf1dcceb335f4fbf7c11da98bca68dca8dcb983153a9"}, + {file = "matplotlib-3.9.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6e9de2b390d253a508dd497e9b5579f3a851f208763ed67fdca5dc0c3ea6849c"}, + {file = "matplotlib-3.9.3-cp312-cp312-win_amd64.whl", hash = "sha256:d796272408f8567ff7eaa00eb2856b3a00524490e47ad505b0b4ca6bb8a7411f"}, + {file = "matplotlib-3.9.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:203d18df84f5288973b2d56de63d4678cc748250026ca9e1ad8f8a0fd8a75d83"}, + {file = "matplotlib-3.9.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b651b0d3642991259109dc0351fc33ad44c624801367bb8307be9bfc35e427ad"}, + {file = "matplotlib-3.9.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66d7b171fecf96940ce069923a08ba3df33ef542de82c2ff4fe8caa8346fa95a"}, + {file = "matplotlib-3.9.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6be0ba61f6ff2e6b68e4270fb63b6813c9e7dec3d15fc3a93f47480444fd72f0"}, + {file = "matplotlib-3.9.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9d6b2e8856dec3a6db1ae51aec85c82223e834b228c1d3228aede87eee2b34f9"}, + {file = "matplotlib-3.9.3-cp313-cp313-win_amd64.whl", hash = "sha256:90a85a004fefed9e583597478420bf904bb1a065b0b0ee5b9d8d31b04b0f3f70"}, + {file = "matplotlib-3.9.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3119b2f16de7f7b9212ba76d8fe6a0e9f90b27a1e04683cd89833a991682f639"}, + {file = "matplotlib-3.9.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:87ad73763d93add1b6c1f9fcd33af662fd62ed70e620c52fcb79f3ac427cf3a6"}, + {file = "matplotlib-3.9.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:026bdf3137ab6022c866efa4813b6bbeddc2ed4c9e7e02f0e323a7bca380dfa0"}, + {file = "matplotlib-3.9.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:760a5e89ebbb172989e8273024a1024b0f084510b9105261b3b00c15e9c9f006"}, + {file = "matplotlib-3.9.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:a42b9dc42de2cfe357efa27d9c50c7833fc5ab9b2eb7252ccd5d5f836a84e1e4"}, + {file = "matplotlib-3.9.3-cp313-cp313t-win_amd64.whl", hash = "sha256:e0fcb7da73fbf67b5f4bdaa57d85bb585a4e913d4a10f3e15b32baea56a67f0a"}, + {file = "matplotlib-3.9.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:031b7f5b8e595cc07def77ec5b58464e9bb67dc5760be5d6f26d9da24892481d"}, + {file = "matplotlib-3.9.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9fa6e193c14d6944e0685cdb527cb6b38b0e4a518043e7212f214113af7391da"}, + {file = "matplotlib-3.9.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e6eefae6effa0c35bbbc18c25ee6e0b1da44d2359c3cd526eb0c9e703cf055d"}, + {file = "matplotlib-3.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10d3e5c7a99bd28afb957e1ae661323b0800d75b419f24d041ed1cc5d844a764"}, + {file = "matplotlib-3.9.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:816a966d5d376bf24c92af8f379e78e67278833e4c7cbc9fa41872eec629a060"}, + {file = "matplotlib-3.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fb0b37c896172899a4a93d9442ffdc6f870165f59e05ce2e07c6fded1c15749"}, + {file = 
"matplotlib-3.9.3-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5f2a4ea08e6876206d511365b0bc234edc813d90b930be72c3011bbd7898796f"}, + {file = "matplotlib-3.9.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:9b081dac96ab19c54fd8558fac17c9d2c9cb5cc4656e7ed3261ddc927ba3e2c5"}, + {file = "matplotlib-3.9.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a0a63cb8404d1d1f94968ef35738900038137dab8af836b6c21bb6f03d75465"}, + {file = "matplotlib-3.9.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:896774766fd6be4571a43bc2fcbcb1dcca0807e53cab4a5bf88c4aa861a08e12"}, + {file = "matplotlib-3.9.3.tar.gz", hash = "sha256:cd5dbbc8e25cad5f706845c4d100e2c8b34691b412b93717ce38d8ae803bcfa5"}, +] + +[package.dependencies] +contourpy = ">=1.0.1" +cycler = ">=0.10" +fonttools = ">=4.22.0" +kiwisolver = ">=1.3.1" +numpy = ">=1.23" +packaging = ">=20.0" +pillow = ">=8" +pyparsing = ">=2.3.1" +python-dateutil = ">=2.7" + +[package.extras] +dev = ["meson-python (>=0.13.1)", "numpy (>=1.25)", "pybind11 (>=2.6,!=2.13.3)", "setuptools (>=64)", "setuptools_scm (>=7)"] + +[[package]] +name = "mypy" +version = "1.13.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"}, + {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"}, + {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"}, + {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"}, + {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"}, + {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"}, + {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"}, + {file = 
"mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"}, + {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"}, + {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"}, + {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"}, + {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"}, + {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"}, + {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"}, + {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"}, + {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"}, + {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + +[[package]] +name = "natsort" +version = "8.4.0" +description = "Simple yet flexible natural sorting in Python." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c"}, + {file = "natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581"}, +] + +[package.extras] +fast = ["fastnumbers (>=2.0.0)"] +icu = ["PyICU (>=1.0.0)"] + +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + +[[package]] +name = "numpy" +version = "1.26.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = 
"numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, +] + +[[package]] +name = "packaging" +version = "24.2" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, + {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, +] + +[[package]] +name = "pandas" +version = "2.2.2" +description = "Powerful data structures for data analysis, time series, and statistics" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = 
"pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, + {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, + {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, + {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"}, + {file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"}, + {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, +] +python-dateutil = ">=2.8.2" +pytz = ">=2020.1" +tzdata = ">=2022.7" + +[package.extras] +all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] +aws = ["s3fs (>=2022.11.0)"] +clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] +compression = ["zstandard (>=0.19.0)"] +computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] +consortium-standard = ["dataframe-api-compat (>=0.1.7)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] +feather = ["pyarrow (>=10.0.1)"] +fss = ["fsspec (>=2022.11.0)"] +gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] +hdf5 = ["tables (>=3.8.0)"] +html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] +mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] +parquet = ["pyarrow (>=10.0.1)"] +performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] +plot = ["matplotlib (>=3.6.3)"] +postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] +pyarrow = ["pyarrow (>=10.0.1)"] +spss = ["pyreadstat (>=1.2.0)"] +sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] +test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.9.2)"] + +[[package]] +name = "pandas-stubs" +version = "2.2.3.241126" +description = "Type annotations for pandas" +optional = false +python-versions = ">=3.10" +files = [ + {file = "pandas_stubs-2.2.3.241126-py3-none-any.whl", hash = "sha256:74aa79c167af374fe97068acc90776c0ebec5266a6e5c69fe11e9c2cf51f2267"}, + {file = "pandas_stubs-2.2.3.241126.tar.gz", hash = "sha256:cf819383c6d9ae7d4dabf34cd47e1e45525bb2f312e6ad2939c2c204cb708acd"}, +] + +[package.dependencies] +numpy = ">=1.23.5" 
+types-pytz = ">=2022.1.1" + +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + +[[package]] +name = "pillow" +version = "11.0.0" +description = "Python Imaging Library (Fork)" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pillow-11.0.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947"}, + {file = "pillow-11.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f"}, + {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb"}, + {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97"}, + {file = "pillow-11.0.0-cp310-cp310-win32.whl", hash = "sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50"}, + {file = "pillow-11.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c"}, + {file = "pillow-11.0.0-cp310-cp310-win_arm64.whl", hash = "sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1"}, + {file = "pillow-11.0.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc"}, + {file = "pillow-11.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa"}, + {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306"}, + {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9"}, + {file = "pillow-11.0.0-cp311-cp311-win32.whl", hash = 
"sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5"}, + {file = "pillow-11.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291"}, + {file = "pillow-11.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9"}, + {file = "pillow-11.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923"}, + {file = "pillow-11.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc"}, + {file = "pillow-11.0.0-cp312-cp312-win32.whl", hash = "sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6"}, + {file = "pillow-11.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47"}, + {file = "pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb"}, + {file = "pillow-11.0.0-cp313-cp313-win32.whl", hash = "sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798"}, + {file = "pillow-11.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de"}, + {file = "pillow-11.0.0-cp313-cp313-win_arm64.whl", hash = 
"sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a"}, + {file = "pillow-11.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8"}, + {file = "pillow-11.0.0-cp313-cp313t-win32.whl", hash = "sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8"}, + {file = "pillow-11.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904"}, + {file = "pillow-11.0.0-cp313-cp313t-win_arm64.whl", hash = "sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3"}, + {file = "pillow-11.0.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba"}, + {file = "pillow-11.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e"}, + {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f"}, + {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae"}, + {file = "pillow-11.0.0-cp39-cp39-win32.whl", hash = "sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4"}, + {file = "pillow-11.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd"}, + {file = "pillow-11.0.0-cp39-cp39-win_arm64.whl", hash = "sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2"}, + {file = 
"pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944"}, + {file = "pillow-11.0.0.tar.gz", hash = "sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=8.1)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] +fpx = ["olefile"] +mic = ["olefile"] +tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] +typing = ["typing-extensions"] +xmp = ["defusedxml"] + +[[package]] +name = "pip" +version = "24.3.1" +description = "The PyPA recommended tool for installing Python packages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pip-24.3.1-py3-none-any.whl", hash = "sha256:3790624780082365f47549d032f3770eeb2b1e8bd1f7b2e02dace1afa361b4ed"}, + {file = "pip-24.3.1.tar.gz", hash = "sha256:ebcb60557f2aefabc2e0f918751cd24ea0d56d8ec5445fe1807f1d2109660b99"}, +] + +[[package]] +name = "pip-tools" +version = "7.4.1" +description = "pip-tools keeps your pinned dependencies fresh." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pip-tools-7.4.1.tar.gz", hash = "sha256:864826f5073864450e24dbeeb85ce3920cdfb09848a3d69ebf537b521f14bcc9"}, + {file = "pip_tools-7.4.1-py3-none-any.whl", hash = "sha256:4c690e5fbae2f21e87843e89c26191f0d9454f362d8acdbd695716493ec8b3a9"}, +] + +[package.dependencies] +build = ">=1.0.0" +click = ">=8" +pip = ">=22.2" +pyproject_hooks = "*" +setuptools = "*" +tomli = {version = "*", markers = "python_version < \"3.11\""} +wheel = "*" + +[package.extras] +coverage = ["covdefaults", "pytest-cov"] +testing = ["flit_core (>=2,<4)", "poetry_core (>=1.0.0)", "pytest (>=7.2.0)", "pytest-rerunfailures", "pytest-xdist", "tomli-w"] + +[[package]] +name = "platformdirs" +version = "4.3.6" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, + {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] +type = ["mypy (>=1.11.2)"] + +[[package]] +name = "pluggy" +version = "1.5.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "pre-commit" +version = "3.8.0" +description = "A framework for managing and maintaining multi-language pre-commit hooks." +optional = false +python-versions = ">=3.9" +files = [ + {file = "pre_commit-3.8.0-py2.py3-none-any.whl", hash = "sha256:9a90a53bf82fdd8778d58085faf8d83df56e40dfe18f45b19446e26bf1b3a63f"}, + {file = "pre_commit-3.8.0.tar.gz", hash = "sha256:8bb6494d4a20423842e198980c9ecf9f96607a07ea29549e180eef9ae80fe7af"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + +[[package]] +name = "pyarrow" +version = "15.0.2" +description = "Python library for Apache Arrow" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyarrow-15.0.2-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:88b340f0a1d05b5ccc3d2d986279045655b1fe8e41aba6ca44ea28da0d1455d8"}, + {file = "pyarrow-15.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eaa8f96cecf32da508e6c7f69bb8401f03745c050c1dd42ec2596f2e98deecac"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23c6753ed4f6adb8461e7c383e418391b8d8453c5d67e17f416c3a5d5709afbd"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f639c059035011db8c0497e541a8a45d98a58dbe34dc8fadd0ef128f2cee46e5"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:290e36a59a0993e9a5224ed2fb3e53375770f07379a0ea03ee2fce2e6d30b423"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:06c2bb2a98bc792f040bef31ad3e9be6a63d0cb39189227c08a7d955db96816e"}, + {file = "pyarrow-15.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:f7a197f3670606a960ddc12adbe8075cea5f707ad7bf0dffa09637fdbb89f76c"}, + {file = "pyarrow-15.0.2-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:5f8bc839ea36b1f99984c78e06e7a06054693dc2af8920f6fb416b5bca9944e4"}, + {file = "pyarrow-15.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f5e81dfb4e519baa6b4c80410421528c214427e77ca0ea9461eb4097c328fa33"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a4f240852b302a7af4646c8bfe9950c4691a419847001178662a98915fd7ee7"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e7d9cfb5a1e648e172428c7a42b744610956f3b70f524aa3a6c02a448ba853e"}, + {file = 
"pyarrow-15.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2d4f905209de70c0eb5b2de6763104d5a9a37430f137678edfb9a675bac9cd98"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:90adb99e8ce5f36fbecbbc422e7dcbcbed07d985eed6062e459e23f9e71fd197"}, + {file = "pyarrow-15.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:b116e7fd7889294cbd24eb90cd9bdd3850be3738d61297855a71ac3b8124ee38"}, + {file = "pyarrow-15.0.2-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:25335e6f1f07fdaa026a61c758ee7d19ce824a866b27bba744348fa73bb5a440"}, + {file = "pyarrow-15.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90f19e976d9c3d8e73c80be84ddbe2f830b6304e4c576349d9360e335cd627fc"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a22366249bf5fd40ddacc4f03cd3160f2d7c247692945afb1899bab8a140ddfb"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2a335198f886b07e4b5ea16d08ee06557e07db54a8400cc0d03c7f6a22f785f"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3e6d459c0c22f0b9c810a3917a1de3ee704b021a5fb8b3bacf968eece6df098f"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:033b7cad32198754d93465dcfb71d0ba7cb7cd5c9afd7052cab7214676eec38b"}, + {file = "pyarrow-15.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:29850d050379d6e8b5a693098f4de7fd6a2bea4365bfd073d7c57c57b95041ee"}, + {file = "pyarrow-15.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:7167107d7fb6dcadb375b4b691b7e316f4368f39f6f45405a05535d7ad5e5058"}, + {file = "pyarrow-15.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e85241b44cc3d365ef950432a1b3bd44ac54626f37b2e3a0cc89c20e45dfd8bf"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:248723e4ed3255fcd73edcecc209744d58a9ca852e4cf3d2577811b6d4b59818"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ff3bdfe6f1b81ca5b73b70a8d482d37a766433823e0c21e22d1d7dde76ca33f"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f3d77463dee7e9f284ef42d341689b459a63ff2e75cee2b9302058d0d98fe142"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:8c1faf2482fb89766e79745670cbca04e7018497d85be9242d5350cba21357e1"}, + {file = "pyarrow-15.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:28f3016958a8e45a1069303a4a4f6a7d4910643fc08adb1e2e4a7ff056272ad3"}, + {file = "pyarrow-15.0.2-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:89722cb64286ab3d4daf168386f6968c126057b8c7ec3ef96302e81d8cdb8ae4"}, + {file = "pyarrow-15.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd0ba387705044b3ac77b1b317165c0498299b08261d8122c96051024f953cd5"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad2459bf1f22b6a5cdcc27ebfd99307d5526b62d217b984b9f5c974651398832"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58922e4bfece8b02abf7159f1f53a8f4d9f8e08f2d988109126c17c3bb261f22"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:adccc81d3dc0478ea0b498807b39a8d41628fa9210729b2f718b78cb997c7c91"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:8bd2baa5fe531571847983f36a30ddbf65261ef23e496862ece83bdceb70420d"}, + {file = "pyarrow-15.0.2-cp39-cp39-win_amd64.whl", hash = 
"sha256:6669799a1d4ca9da9c7e06ef48368320f5856f36f9a4dd31a11839dda3f6cc8c"}, + {file = "pyarrow-15.0.2.tar.gz", hash = "sha256:9c9bc803cb3b7bfacc1e96ffbfd923601065d9d3f911179d81e72d99fd74a3d9"}, +] + +[package.dependencies] +numpy = ">=1.16.6,<2" + +[[package]] +name = "pydantic" +version = "2.10.2" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic-2.10.2-py3-none-any.whl", hash = "sha256:cfb96e45951117c3024e6b67b25cdc33a3cb7b2fa62e239f7af1378358a1d99e"}, + {file = "pydantic-2.10.2.tar.gz", hash = "sha256:2bc2d7f17232e0841cbba4641e65ba1eb6fafb3a08de3a091ff3ce14a197c4fa"}, +] + +[package.dependencies] +annotated-types = ">=0.6.0" +pydantic-core = "2.27.1" +typing-extensions = ">=4.12.2" + +[package.extras] +email = ["email-validator (>=2.0.0)"] +timezone = ["tzdata"] + +[[package]] +name = "pydantic-core" +version = "2.27.1" +description = "Core functionality for Pydantic validation and serialization" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.27.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:71a5e35c75c021aaf400ac048dacc855f000bdfed91614b4a726f7432f1f3d6a"}, + {file = "pydantic_core-2.27.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f82d068a2d6ecfc6e054726080af69a6764a10015467d7d7b9f66d6ed5afa23b"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:121ceb0e822f79163dd4699e4c54f5ad38b157084d97b34de8b232bcaad70278"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4603137322c18eaf2e06a4495f426aa8d8388940f3c457e7548145011bb68e05"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a33cd6ad9017bbeaa9ed78a2e0752c5e250eafb9534f308e7a5f7849b0b1bfb4"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15cc53a3179ba0fcefe1e3ae50beb2784dede4003ad2dfd24f81bba4b23a454f"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45d9c5eb9273aa50999ad6adc6be5e0ecea7e09dbd0d31bd0c65a55a2592ca08"}, + {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8bf7b66ce12a2ac52d16f776b31d16d91033150266eb796967a7e4621707e4f6"}, + {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:655d7dd86f26cb15ce8a431036f66ce0318648f8853d709b4167786ec2fa4807"}, + {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:5556470f1a2157031e676f776c2bc20acd34c1990ca5f7e56f1ebf938b9ab57c"}, + {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f69ed81ab24d5a3bd93861c8c4436f54afdf8e8cc421562b0c7504cf3be58206"}, + {file = "pydantic_core-2.27.1-cp310-none-win32.whl", hash = "sha256:f5a823165e6d04ccea61a9f0576f345f8ce40ed533013580e087bd4d7442b52c"}, + {file = "pydantic_core-2.27.1-cp310-none-win_amd64.whl", hash = "sha256:57866a76e0b3823e0b56692d1a0bf722bffb324839bb5b7226a7dbd6c9a40b17"}, + {file = "pydantic_core-2.27.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac3b20653bdbe160febbea8aa6c079d3df19310d50ac314911ed8cc4eb7f8cb8"}, + {file = "pydantic_core-2.27.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a5a8e19d7c707c4cadb8c18f5f60c843052ae83c20fa7d44f41594c644a1d330"}, + {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:7f7059ca8d64fea7f238994c97d91f75965216bcbe5f695bb44f354893f11d52"}, + {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bed0f8a0eeea9fb72937ba118f9db0cb7e90773462af7962d382445f3005e5a4"}, + {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3cb37038123447cf0f3ea4c74751f6a9d7afef0eb71aa07bf5f652b5e6a132c"}, + {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84286494f6c5d05243456e04223d5a9417d7f443c3b76065e75001beb26f88de"}, + {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acc07b2cfc5b835444b44a9956846b578d27beeacd4b52e45489e93276241025"}, + {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4fefee876e07a6e9aad7a8c8c9f85b0cdbe7df52b8a9552307b09050f7512c7e"}, + {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:258c57abf1188926c774a4c94dd29237e77eda19462e5bb901d88adcab6af919"}, + {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:35c14ac45fcfdf7167ca76cc80b2001205a8d5d16d80524e13508371fb8cdd9c"}, + {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d1b26e1dff225c31897696cab7d4f0a315d4c0d9e8666dbffdb28216f3b17fdc"}, + {file = "pydantic_core-2.27.1-cp311-none-win32.whl", hash = "sha256:2cdf7d86886bc6982354862204ae3b2f7f96f21a3eb0ba5ca0ac42c7b38598b9"}, + {file = "pydantic_core-2.27.1-cp311-none-win_amd64.whl", hash = "sha256:3af385b0cee8df3746c3f406f38bcbfdc9041b5c2d5ce3e5fc6637256e60bbc5"}, + {file = "pydantic_core-2.27.1-cp311-none-win_arm64.whl", hash = "sha256:81f2ec23ddc1b476ff96563f2e8d723830b06dceae348ce02914a37cb4e74b89"}, + {file = "pydantic_core-2.27.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9cbd94fc661d2bab2bc702cddd2d3370bbdcc4cd0f8f57488a81bcce90c7a54f"}, + {file = "pydantic_core-2.27.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5f8c4718cd44ec1580e180cb739713ecda2bdee1341084c1467802a417fe0f02"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15aae984e46de8d376df515f00450d1522077254ef6b7ce189b38ecee7c9677c"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ba5e3963344ff25fc8c40da90f44b0afca8cfd89d12964feb79ac1411a260ac"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:992cea5f4f3b29d6b4f7f1726ed8ee46c8331c6b4eed6db5b40134c6fe1768bb"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0325336f348dbee6550d129b1627cb8f5351a9dc91aad141ffb96d4937bd9529"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7597c07fbd11515f654d6ece3d0e4e5093edc30a436c63142d9a4b8e22f19c35"}, + {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3bbd5d8cc692616d5ef6fbbbd50dbec142c7e6ad9beb66b78a96e9c16729b089"}, + {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:dc61505e73298a84a2f317255fcc72b710b72980f3a1f670447a21efc88f8381"}, + {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:e1f735dc43da318cad19b4173dd1ffce1d84aafd6c9b782b3abc04a0d5a6f5bb"}, + {file = 
"pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f4e5658dbffe8843a0f12366a4c2d1c316dbe09bb4dfbdc9d2d9cd6031de8aae"}, + {file = "pydantic_core-2.27.1-cp312-none-win32.whl", hash = "sha256:672ebbe820bb37988c4d136eca2652ee114992d5d41c7e4858cdd90ea94ffe5c"}, + {file = "pydantic_core-2.27.1-cp312-none-win_amd64.whl", hash = "sha256:66ff044fd0bb1768688aecbe28b6190f6e799349221fb0de0e6f4048eca14c16"}, + {file = "pydantic_core-2.27.1-cp312-none-win_arm64.whl", hash = "sha256:9a3b0793b1bbfd4146304e23d90045f2a9b5fd5823aa682665fbdaf2a6c28f3e"}, + {file = "pydantic_core-2.27.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f216dbce0e60e4d03e0c4353c7023b202d95cbaeff12e5fd2e82ea0a66905073"}, + {file = "pydantic_core-2.27.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a2e02889071850bbfd36b56fd6bc98945e23670773bc7a76657e90e6b6603c08"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42b0e23f119b2b456d07ca91b307ae167cc3f6c846a7b169fca5326e32fdc6cf"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:764be71193f87d460a03f1f7385a82e226639732214b402f9aa61f0d025f0737"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c00666a3bd2f84920a4e94434f5974d7bbc57e461318d6bb34ce9cdbbc1f6b2"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ccaa88b24eebc0f849ce0a4d09e8a408ec5a94afff395eb69baf868f5183107"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c65af9088ac534313e1963443d0ec360bb2b9cba6c2909478d22c2e363d98a51"}, + {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:206b5cf6f0c513baffaeae7bd817717140770c74528f3e4c3e1cec7871ddd61a"}, + {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:062f60e512fc7fff8b8a9d680ff0ddaaef0193dba9fa83e679c0c5f5fbd018bc"}, + {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:a0697803ed7d4af5e4c1adf1670af078f8fcab7a86350e969f454daf598c4960"}, + {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:58ca98a950171f3151c603aeea9303ef6c235f692fe555e883591103da709b23"}, + {file = "pydantic_core-2.27.1-cp313-none-win32.whl", hash = "sha256:8065914ff79f7eab1599bd80406681f0ad08f8e47c880f17b416c9f8f7a26d05"}, + {file = "pydantic_core-2.27.1-cp313-none-win_amd64.whl", hash = "sha256:ba630d5e3db74c79300d9a5bdaaf6200172b107f263c98a0539eeecb857b2337"}, + {file = "pydantic_core-2.27.1-cp313-none-win_arm64.whl", hash = "sha256:45cf8588c066860b623cd11c4ba687f8d7175d5f7ef65f7129df8a394c502de5"}, + {file = "pydantic_core-2.27.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:5897bec80a09b4084aee23f9b73a9477a46c3304ad1d2d07acca19723fb1de62"}, + {file = "pydantic_core-2.27.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0165ab2914379bd56908c02294ed8405c252250668ebcb438a55494c69f44ab"}, + {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b9af86e1d8e4cfc82c2022bfaa6f459381a50b94a29e95dcdda8442d6d83864"}, + {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f6c8a66741c5f5447e047ab0ba7a1c61d1e95580d64bce852e3df1f895c4067"}, + {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:9a42d6a8156ff78981f8aa56eb6394114e0dedb217cf8b729f438f643608cbcd"}, + {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64c65f40b4cd8b0e049a8edde07e38b476da7e3aaebe63287c899d2cff253fa5"}, + {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdcf339322a3fae5cbd504edcefddd5a50d9ee00d968696846f089b4432cf78"}, + {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bf99c8404f008750c846cb4ac4667b798a9f7de673ff719d705d9b2d6de49c5f"}, + {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8f1edcea27918d748c7e5e4d917297b2a0ab80cad10f86631e488b7cddf76a36"}, + {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_armv7l.whl", hash = "sha256:159cac0a3d096f79ab6a44d77a961917219707e2a130739c64d4dd46281f5c2a"}, + {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:029d9757eb621cc6e1848fa0b0310310de7301057f623985698ed7ebb014391b"}, + {file = "pydantic_core-2.27.1-cp38-none-win32.whl", hash = "sha256:a28af0695a45f7060e6f9b7092558a928a28553366519f64083c63a44f70e618"}, + {file = "pydantic_core-2.27.1-cp38-none-win_amd64.whl", hash = "sha256:2d4567c850905d5eaaed2f7a404e61012a51caf288292e016360aa2b96ff38d4"}, + {file = "pydantic_core-2.27.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e9386266798d64eeb19dd3677051f5705bf873e98e15897ddb7d76f477131967"}, + {file = "pydantic_core-2.27.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4228b5b646caa73f119b1ae756216b59cc6e2267201c27d3912b592c5e323b60"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b3dfe500de26c52abe0477dde16192ac39c98f05bf2d80e76102d394bd13854"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aee66be87825cdf72ac64cb03ad4c15ffef4143dbf5c113f64a5ff4f81477bf9"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b748c44bb9f53031c8cbc99a8a061bc181c1000c60a30f55393b6e9c45cc5bd"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ca038c7f6a0afd0b2448941b6ef9d5e1949e999f9e5517692eb6da58e9d44be"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e0bd57539da59a3e4671b90a502da9a28c72322a4f17866ba3ac63a82c4498e"}, + {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ac6c2c45c847bbf8f91930d88716a0fb924b51e0c6dad329b793d670ec5db792"}, + {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b94d4ba43739bbe8b0ce4262bcc3b7b9f31459ad120fb595627eaeb7f9b9ca01"}, + {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:00e6424f4b26fe82d44577b4c842d7df97c20be6439e8e685d0d715feceb9fb9"}, + {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:38de0a70160dd97540335b7ad3a74571b24f1dc3ed33f815f0880682e6880131"}, + {file = "pydantic_core-2.27.1-cp39-none-win32.whl", hash = "sha256:7ccebf51efc61634f6c2344da73e366c75e735960b5654b63d7e6f69a5885fa3"}, + {file = "pydantic_core-2.27.1-cp39-none-win_amd64.whl", hash = "sha256:a57847b090d7892f123726202b7daa20df6694cbd583b67a592e856bff603d6c"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = 
"sha256:3fa80ac2bd5856580e242dbc202db873c60a01b20309c8319b5c5986fbe53ce6"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d950caa237bb1954f1b8c9227b5065ba6875ac9771bb8ec790d956a699b78676"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e4216e64d203e39c62df627aa882f02a2438d18a5f21d7f721621f7a5d3611d"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02a3d637bd387c41d46b002f0e49c52642281edacd2740e5a42f7017feea3f2c"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:161c27ccce13b6b0c8689418da3885d3220ed2eae2ea5e9b2f7f3d48f1d52c27"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:19910754e4cc9c63bc1c7f6d73aa1cfee82f42007e407c0f413695c2f7ed777f"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:e173486019cc283dc9778315fa29a363579372fe67045e971e89b6365cc035ed"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:af52d26579b308921b73b956153066481f064875140ccd1dfd4e77db89dbb12f"}, + {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:981fb88516bd1ae8b0cbbd2034678a39dedc98752f264ac9bc5839d3923fa04c"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5fde892e6c697ce3e30c61b239330fc5d569a71fefd4eb6512fc6caec9dd9e2f"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:816f5aa087094099fff7edabb5e01cc370eb21aa1a1d44fe2d2aefdfb5599b31"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c10c309e18e443ddb108f0ef64e8729363adbfd92d6d57beec680f6261556f3"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98476c98b02c8e9b2eec76ac4156fd006628b1b2d0ef27e548ffa978393fd154"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c3027001c28434e7ca5a6e1e527487051136aa81803ac812be51802150d880dd"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7699b1df36a48169cdebda7ab5a2bac265204003f153b4bd17276153d997670a"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1c39b07d90be6b48968ddc8c19e7585052088fd7ec8d568bb31ff64c70ae3c97"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:46ccfe3032b3915586e469d4972973f893c0a2bb65669194a5bdea9bacc088c2"}, + {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:62ba45e21cf6571d7f716d903b5b7b6d2617e2d5d67c0923dc47b9d41369f840"}, + {file = "pydantic_core-2.27.1.tar.gz", hash = "sha256:62a763352879b84aa31058fc931884055fd75089cccbd9d58bb6afd01141b235"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + +[[package]] +name = "pydeseq2" +version = "0.4.9" +description = "A python implementation of DESeq2." 
+optional = false +python-versions = ">=3.9.0" +files = [ + {file = "pydeseq2-0.4.9-py3-none-any.whl", hash = "sha256:7f112fe1dfd3cef1c19e1ead67379d348f2517ad0594fac0fcbae847d0d62020"}, + {file = "pydeseq2-0.4.9.tar.gz", hash = "sha256:0375207775953f43f84ed4279fcb9f11a430d79d038e158bbe74acbc58326d31"}, +] + +[package.dependencies] +anndata = ">=0.8.0" +matplotlib = ">=3.6.2" +numpy = ">=1.23.0" +pandas = ">=1.4.0" +scikit-learn = ">=1.1.0" +scipy = ">=1.11.0" + +[package.extras] +dev = ["coverage", "mypy", "numpydoc", "pandas-stubs", "pre-commit (>=2.13.0)", "pytest (>=6.2.4)"] + +[[package]] +name = "pyparsing" +version = "3.2.0" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pyparsing-3.2.0-py3-none-any.whl", hash = "sha256:93d9577b88da0bbea8cc8334ee8b918ed014968fd2ec383e868fb8afb1ccef84"}, + {file = "pyparsing-3.2.0.tar.gz", hash = "sha256:cbf74e27246d595d9a74b186b810f6fbb86726dbf3b9532efb343f6d7294fe9c"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + +[[package]] +name = "pyproject-hooks" +version = "1.2.0" +description = "Wrappers to call pyproject.toml-based build backend hooks." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913"}, + {file = "pyproject_hooks-1.2.0.tar.gz", hash = "sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8"}, +] + +[[package]] +name = "pytest" +version = "8.3.4" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"}, + {file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.5,<2" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "python-slugify" +version = "8.0.4" +description = "A Python slugify application that also handles Unicode" +optional = false +python-versions = ">=3.7" +files = [ + {file = "python-slugify-8.0.4.tar.gz", hash = "sha256:59202371d1d05b54a9e7720c5e038f928f45daaffe41dd10822f3907b937c856"}, + {file = "python_slugify-8.0.4-py2.py3-none-any.whl", hash = "sha256:276540b79961052b66b7d116620b36518847f52d5fd9e3a70164fc8c50faa6b8"}, +] + +[package.dependencies] +text-unidecode = ">=1.3" + +[package.extras] +unidecode = ["Unidecode (>=1.1.1)"] + +[[package]] +name = "pytz" +version 
= "2024.2" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725"}, + {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"}, +] + +[[package]] +name = "pywin32" +version = "308" +description = "Python for Window Extensions" +optional = false +python-versions = "*" +files = [ + {file = "pywin32-308-cp310-cp310-win32.whl", hash = "sha256:796ff4426437896550d2981b9c2ac0ffd75238ad9ea2d3bfa67a1abd546d262e"}, + {file = "pywin32-308-cp310-cp310-win_amd64.whl", hash = "sha256:4fc888c59b3c0bef905ce7eb7e2106a07712015ea1c8234b703a088d46110e8e"}, + {file = "pywin32-308-cp310-cp310-win_arm64.whl", hash = "sha256:a5ab5381813b40f264fa3495b98af850098f814a25a63589a8e9eb12560f450c"}, + {file = "pywin32-308-cp311-cp311-win32.whl", hash = "sha256:5d8c8015b24a7d6855b1550d8e660d8daa09983c80e5daf89a273e5c6fb5095a"}, + {file = "pywin32-308-cp311-cp311-win_amd64.whl", hash = "sha256:575621b90f0dc2695fec346b2d6302faebd4f0f45c05ea29404cefe35d89442b"}, + {file = "pywin32-308-cp311-cp311-win_arm64.whl", hash = "sha256:100a5442b7332070983c4cd03f2e906a5648a5104b8a7f50175f7906efd16bb6"}, + {file = "pywin32-308-cp312-cp312-win32.whl", hash = "sha256:587f3e19696f4bf96fde9d8a57cec74a57021ad5f204c9e627e15c33ff568897"}, + {file = "pywin32-308-cp312-cp312-win_amd64.whl", hash = "sha256:00b3e11ef09ede56c6a43c71f2d31857cf7c54b0ab6e78ac659497abd2834f47"}, + {file = "pywin32-308-cp312-cp312-win_arm64.whl", hash = "sha256:9b4de86c8d909aed15b7011182c8cab38c8850de36e6afb1f0db22b8959e3091"}, + {file = "pywin32-308-cp313-cp313-win32.whl", hash = "sha256:1c44539a37a5b7b21d02ab34e6a4d314e0788f1690d65b48e9b0b89f31abbbed"}, + {file = "pywin32-308-cp313-cp313-win_amd64.whl", hash = "sha256:fd380990e792eaf6827fcb7e187b2b4b1cede0585e3d0c9e84201ec27b9905e4"}, + {file = "pywin32-308-cp313-cp313-win_arm64.whl", hash = "sha256:ef313c46d4c18dfb82a2431e3051ac8f112ccee1a34f29c263c583c568db63cd"}, + {file = "pywin32-308-cp37-cp37m-win32.whl", hash = "sha256:1f696ab352a2ddd63bd07430080dd598e6369152ea13a25ebcdd2f503a38f1ff"}, + {file = "pywin32-308-cp37-cp37m-win_amd64.whl", hash = "sha256:13dcb914ed4347019fbec6697a01a0aec61019c1046c2b905410d197856326a6"}, + {file = "pywin32-308-cp38-cp38-win32.whl", hash = "sha256:5794e764ebcabf4ff08c555b31bd348c9025929371763b2183172ff4708152f0"}, + {file = "pywin32-308-cp38-cp38-win_amd64.whl", hash = "sha256:3b92622e29d651c6b783e368ba7d6722b1634b8e70bd376fd7610fe1992e19de"}, + {file = "pywin32-308-cp39-cp39-win32.whl", hash = "sha256:7873ca4dc60ab3287919881a7d4f88baee4a6e639aa6962de25a98ba6b193341"}, + {file = "pywin32-308-cp39-cp39-win_amd64.whl", hash = "sha256:71b3322d949b4cc20776436a9c9ba0eeedcbc9c650daa536df63f0ff111bb920"}, +] + +[[package]] +name = "pyyaml" +version = "6.0.2" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = 
"PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = 
"sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = 
"PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, +] + +[[package]] +name = "requests" +version = "2.31.0" +description = "Python HTTP for Humans." +optional = false +python-versions = ">=3.7" +files = [ + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, +] + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] + +[[package]] +name = "ruff" +version = "0.2.2" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.2.2-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:0a9efb032855ffb3c21f6405751d5e147b0c6b631e3ca3f6b20f917572b97eb6"}, + {file = "ruff-0.2.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d450b7fbff85913f866a5384d8912710936e2b96da74541c82c1b458472ddb39"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecd46e3106850a5c26aee114e562c329f9a1fbe9e4821b008c4404f64ff9ce73"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e22676a5b875bd72acd3d11d5fa9075d3a5f53b877fe7b4793e4673499318ba"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1695700d1e25a99d28f7a1636d85bafcc5030bba9d0578c0781ba1790dbcf51c"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:b0c232af3d0bd8f521806223723456ffebf8e323bd1e4e82b0befb20ba18388e"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f63d96494eeec2fc70d909393bcd76c69f35334cdbd9e20d089fb3f0640216ca"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6a61ea0ff048e06de273b2e45bd72629f470f5da8f71daf09fe481278b175001"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e1439c8f407e4f356470e54cdecdca1bd5439a0673792dbe34a2b0a551a2fe3"}, + {file = "ruff-0.2.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:940de32dc8853eba0f67f7198b3e79bc6ba95c2edbfdfac2144c8235114d6726"}, + {file = "ruff-0.2.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0c126da55c38dd917621552ab430213bdb3273bb10ddb67bc4b761989210eb6e"}, + {file = "ruff-0.2.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3b65494f7e4bed2e74110dac1f0d17dc8e1f42faaa784e7c58a98e335ec83d7e"}, + {file = "ruff-0.2.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1ec49be4fe6ddac0503833f3ed8930528e26d1e60ad35c2446da372d16651ce9"}, + {file = "ruff-0.2.2-py3-none-win32.whl", hash = "sha256:d920499b576f6c68295bc04e7b17b6544d9d05f196bb3aac4358792ef6f34325"}, + {file = "ruff-0.2.2-py3-none-win_amd64.whl", hash = "sha256:cc9a91ae137d687f43a44c900e5d95e9617cb37d4c989e462980ba27039d239d"}, + {file = "ruff-0.2.2-py3-none-win_arm64.whl", hash = 
"sha256:c9d15fc41e6054bfc7200478720570078f0b41c9ae4f010bcc16bd6f4d1aacdd"}, + {file = "ruff-0.2.2.tar.gz", hash = "sha256:e62ed7f36b3068a30ba39193a14274cd706bc486fad521276458022f7bccb31d"}, +] + +[[package]] +name = "scikit-learn" +version = "1.5.2" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"}, + {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"}, + {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, + {file = 
"scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"}, + {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"}, + {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=3.1.0" + +[package.extras] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] + +[[package]] +name = "scipy" +version = "1.14.1" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.10" +files = [ + {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"}, + {file = "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = 
"sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"}, + {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"}, + {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"}, + {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"}, + {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"}, + {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"}, + {file = 
"scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"}, + {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"}, + {file = "scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"}, +] + +[package.dependencies] +numpy = ">=1.23.5,<2.3" + +[package.extras] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + +[[package]] +name = "setuptools" +version = "75.6.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = false +python-versions = ">=3.9" +files = [ + {file = "setuptools-75.6.0-py3-none-any.whl", hash = "sha256:ce74b49e8f7110f9bf04883b730f4765b774ef3ef28f722cce7c273d253aaf7d"}, + {file = "setuptools-75.6.0.tar.gz", hash = "sha256:8199222558df7c86216af4f84c30e9b34a61d8ba19366cc914424cdbd28252f6"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.7.0)"] +core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12,<1.14)", "pytest-mypy"] + +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + +[[package]] +name = "smmap" +version = "5.0.1" +description = "A pure Python implementation of a sliding window memory map manager" +optional = false +python-versions = ">=3.7" +files = [ + {file = "smmap-5.0.1-py3-none-any.whl", hash = 
"sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da"}, + {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"}, +] + +[[package]] +name = "substra" +version = "0.54.0" +description = "Low-level Python library for interacting with a Substra network" +optional = false +python-versions = ">=3.9" +files = [ + {file = "substra-0.54.0-py3-none-any.whl", hash = "sha256:82c17595612650158555886bf6ab7b749e6b9b19bd2a4c27d5b0adc8d4ea7c23"}, + {file = "substra-0.54.0.tar.gz", hash = "sha256:e817273848ef1704a874c6f3961ed586a2c550b407557524da3f0522acb6ef6d"}, +] + +[package.dependencies] +docker = "*" +pydantic = ">=2.3.0,<3.0.0" +python-slugify = "*" +pyyaml = "*" +requests = "!=2.32.*" +tqdm = "*" + +[package.extras] +dev = ["black", "docstring-parser", "flake8", "isort", "pandas", "pytest", "pytest-cov", "pytest-mock", "substratools (>=0.22.0,<0.23.0)", "towncrier"] + +[[package]] +name = "substrafl" +version = "0.47.0" +description = "A high-level federated learning Python library to run federated learning experiments at scale on a Substra network" +optional = false +python-versions = ">=3.9" +files = [ + {file = "substrafl-0.47.0-py3-none-any.whl", hash = "sha256:96a2e61a0f5231f0bcb318cd6473d2ae8021da483eec926ff65b429d98f3a533"}, + {file = "substrafl-0.47.0.tar.gz", hash = "sha256:c19760a669d649b15d1e4e13101a4f528c51ef2c86acf7bbba84b9d3e1a4bbe5"}, +] + +[package.dependencies] +cloudpickle = ">=1.6.0" +numpy = ">=1.24,<2.0" +packaging = "*" +pip = ">=21.2" +pip-tools = "*" +pydantic = ">=2.3.0,<3.0" +six = "*" +substra = ">=0.54.0,<0.55.0" +substratools = ">=0.22.0,<0.23.0" +tqdm = "*" +wheel = "*" + +[package.extras] +dev = ["docker", "nbmake (>=1.4.3)", "pre-commit (>=2.13.0)", "pytest (>=6.2.4)", "pytest-cov (>=2.12.0)", "pytest-mock", "torch (>=1.9.1,!=1.12.0)", "towncrier", "types-pyyaml (>=6.0.0)"] + +[[package]] +name = "substratools" +version = "0.22.0" +description = "Python tools to submit functions on the Substra platform" +optional = false +python-versions = ">=3.9" +files = [ + {file = "substratools-0.22.0-py3-none-any.whl", hash = "sha256:ec27ad659993b60e6a2c2fc4f942b8c203670067ba59fce10bc9118e24ac7835"}, + {file = "substratools-0.22.0.tar.gz", hash = "sha256:acc8b0f27ad296c722e021a55cb1c621fb2d61aeaa095ce877acc6efd3dfdc0a"}, +] + +[package.extras] +dev = ["flake8", "numpy", "pytest", "pytest-cov", "pytest-mock", "towncrier"] + +[[package]] +name = "text-unidecode" +version = "1.3" +description = "The most basic Text::Unidecode port" +optional = false +python-versions = "*" +files = [ + {file = "text-unidecode-1.3.tar.gz", hash = "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93"}, + {file = "text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8"}, +] + +[[package]] +name = "threadpoolctl" +version = "3.5.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, + {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, +] + +[[package]] +name = "toml" +version = "0.10.2" +description = "Python Library for Tom's Obvious, Minimal Language" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "toml-0.10.2-py2.py3-none-any.whl", hash = 
"sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, + {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, +] + +[[package]] +name = "tomli" +version = "2.2.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, + {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"}, + {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"}, + {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"}, + {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"}, + {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"}, + {file = 
"tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"}, + {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"}, + {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"}, + {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"}, + {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, +] + +[[package]] +name = "tqdm" +version = "4.67.1" +description = "Fast, Extensible Progress Meter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"}, + {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"] +discord = ["requests"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + +[[package]] +name = "types-pytz" +version = "2024.2.0.20241003" +description = "Typing stubs for pytz" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-pytz-2024.2.0.20241003.tar.gz", hash = "sha256:575dc38f385a922a212bac00a7d6d2e16e141132a3c955078f4a4fd13ed6cb44"}, + {file = "types_pytz-2024.2.0.20241003-py3-none-any.whl", hash = "sha256:3e22df1336c0c6ad1d29163c8fda82736909eb977281cb823c57f8bae07118b7"}, +] + +[[package]] +name = "typing-extensions" +version = "4.12.2" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, +] + +[[package]] +name = "tzdata" +version = "2024.2" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"}, + {file = "tzdata-2024.2.tar.gz", hash = 
"sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"}, +] + +[[package]] +name = "urllib3" +version = "2.2.3" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=3.8" +files = [ + {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, + {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, +] + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +h2 = ["h2 (>=4,<5)"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[[package]] +name = "virtualenv" +version = "20.28.0" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.8" +files = [ + {file = "virtualenv-20.28.0-py3-none-any.whl", hash = "sha256:23eae1b4516ecd610481eda647f3a7c09aea295055337331bb4e6892ecce47b0"}, + {file = "virtualenv-20.28.0.tar.gz", hash = "sha256:2c9c3262bb8e7b87ea801d715fae4495e6032450c71d2309be9550e7364049aa"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + +[[package]] +name = "wheel" +version = "0.45.1" +description = "A built-package format for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248"}, + {file = "wheel-0.45.1.tar.gz", hash = "sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729"}, +] + +[package.extras] +test = ["pytest (>=6.0.0)", "setuptools (>=65)"] + +[[package]] +name = "win32-setctime" +version = "1.1.0" +description = "A small Python utility to set file creation time on Windows" +optional = false +python-versions = ">=3.5" +files = [ + {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, + {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, +] + +[package.extras] +dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] + +[[package]] +name = "zipp" +version = "3.21.0" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.9" +files = [ + {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"}, + {file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest 
(>=6,!=8.1.*)", "pytest-ignore-flaky"]
+type = ["pytest-mypy"]
+
+[metadata]
+lock-version = "2.0"
+python-versions = ">=3.10,<3.13"
+content-hash = "64d46b0e37c499bc6c4e41f14335c5f6cbfd33c7ee0c302bd128691fdcc869bb"
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..b2cf5a5
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,99 @@
+[tool.poetry]
+name = "fedpydeseq2"
+version = "0.1.0"
+description = "This package is a SubstraFL-based federated learning (FL) implementation of PyDESeq2."
+authors = ["Boris MUZELLEC ", "Ulysse MARTEAU ", "Tanguy MARCHAND "]
+readme = "README.md"
+packages = [{include = "fedpydeseq2"}]
+
+[build-system]
+requires = ["poetry-core>=1.0.0", "setuptools>=65.6.3"]
+build-backend = "poetry.core.masonry.api"
+
+
+[tool.poetry.dependencies]
+python = ">=3.10,<3.13"
+substrafl = "0.47.0"
+numpy = "1.26.4"
+pandas = "2.2.2"
+pyarrow = "15.0.2"
+gitpython = "3.1.43"
+anndata = "0.10.8"
+pydeseq2 = "0.4.9"
+loguru = "0.7.2"
+toml = "0.10.2"
+pyyaml = ">=5.1"
+
+[tool.poetry.group.linting]
+optional = true
+
+[tool.poetry.group.linting.dependencies]
+ruff = "^0.2.2"
+pre-commit = "^3.6.2"
+mypy = "^1.8.0"
+black = "^24.2.0"
+pandas-stubs = "^2.2.0.240218"
+
+[tool.poetry.group.testing]
+optional = true
+
+[tool.poetry.group.testing.dependencies]
+pytest = "^8.0.1"
+fedpydeseq2_datasets = "^0.1.0"
+
+[tool.black]
+line-length = 88
+
+[tool.ruff]
+target-version = "py311"
+line-length = 88
+lint.select = [
+    "F", # Errors detected by Pyflakes
+    "E", # Errors detected by Pycodestyle
+    "W", # Warnings detected by Pycodestyle
+    "I", # isort
+    "D", # pydocstyle
+    "B", # flake8-bugbear
+    "TID", # flake8-tidy-imports
+    "C4", # flake8-comprehensions
+    "BLE", # flake8-blind-except
+    "UP", # pyupgrade
+    "RUF100", # Report unused noqa directives
+    "D401", # Start docstrings with an imperative verb
+    "D415", # End docstrings with a period
+    "D417", # Missing argument descriptions in the docstring
+]
+
+lint.ignore = [
+    # Missing docstring in public package
+    "D104",
+    # Missing docstring in public module
+    "D100",
+    # Missing docstring in __init__
+    "D107",
+    # We don’t want a blank line before a class docstring
+    "D203",
+    # We want docstrings to start immediately after the opening triple quote
+    "D213",
+]
+
+[tool.ruff.lint.isort]
+force-single-line = true
+
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
+
+[tool.ruff.lint.per-file-ignores]
+"tests/*" = ["D"]
+# Ignore unused imports in __init__.py files
+"*/__init__.py" = ["F401", "I"]
+
+
+[tool.pytest.ini_options]
+markers = [
+    "self_hosted_slow: mark a test as a slow test with data",
+    "self_hosted_fast: mark a test as a fast test with data",
+    "docker: mark a test as a docker test",
+    "local: mark a test as a local test",
+    "dev: mark a test as a dev test",
+]
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..c5a8ce7
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,64 @@
+import json
+from pathlib import Path
+
+import fedpydeseq2_datasets
+import pytest
+
+
+@pytest.fixture(scope="session")
+def raw_data_path():
+    """Fixture to get the path to the raw data."""
+    default_paths = Path(__file__).parent / "paths_default.json"
+    specified_paths = Path(__file__).parent / "paths.json"
+    if specified_paths.exists():
+        with open(specified_paths) as f:
+            raw_data_path = json.load(f)["raw_data"]
+    else:
+        with open(default_paths) as f:
+            raw_data_path = json.load(f)["raw_data"]
+    if raw_data_path.startswith("/"):
raw_data_path.startswith("/"): + raw_data_path = Path(raw_data_path) + else: + raw_data_path = Path(__file__).parent / raw_data_path + print("Test raw data path") + return raw_data_path + + +@pytest.fixture(scope="session") +def tmp_processed_data_path(tmpdir_factory): + return Path(tmpdir_factory.mktemp("processed")) + + +@pytest.fixture(scope="session") +def tcga_assets_directory(): + specified_paths = Path(__file__).parent / "paths.json" + if specified_paths.exists(): + with open(specified_paths) as f: + if "assets_tcga" in json.load(f): + tcga_assets_directory = json.load(f)["assets_tcga"] + if tcga_assets_directory.startswith("/"): + return Path(tcga_assets_directory) + return Path(__file__).parent / tcga_assets_directory + + fedpydeseq2_datasets_dir = Path(fedpydeseq2_datasets.__file__).parent + return fedpydeseq2_datasets_dir / "assets/tcga" + + +@pytest.fixture(scope="session") +def local_processed_data_path(): + default_paths = Path(__file__).parent / "paths_default.json" + specified_paths = Path(__file__).parent / "paths.json" + found = False + if specified_paths.exists(): + with open(specified_paths) as f: + if "processed_data" in json.load(f): + found = True + processed_data_path = json.load(f)["processed_data"] + if not found: + with open(default_paths) as f: + processed_data_path = json.load(f)["processed_data"] + if processed_data_path.startswith("/"): + processed_data_path = Path(processed_data_path) + else: + processed_data_path = Path(__file__).parent / processed_data_path + return processed_data_path diff --git a/tests/deseq2_end_to_end/__init__.py b/tests/deseq2_end_to_end/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/deseq2_end_to_end/test_deseq2_pipe.py b/tests/deseq2_end_to_end/test_deseq2_pipe.py new file mode 100644 index 0000000..b7e83e8 --- /dev/null +++ b/tests/deseq2_end_to_end/test_deseq2_pipe.py @@ -0,0 +1,316 @@ +"""Module testing the logmeans computed in the DESeq2Strategy.""" +import os +from itertools import product + +import pytest +from fedpydeseq2_datasets.constants import TCGADatasetNames + +from .test_deseq2_pipe_utils import pipeline_to_test + +COOKS_FILTER = [False, True] + +DESIGN_FACTORS = ["stage", ["gender", "stage"]] + + +@pytest.mark.self_hosted_slow +@pytest.mark.parametrize( + "independent_filter, cooks_filter, design_factors", + [ + (independent_filter, cooks_filter, design_factors) + for independent_filter, cooks_filter, design_factors in product( + [True, False], COOKS_FILTER, DESIGN_FACTORS + ) + if design_factors == "stage" or (not independent_filter and not cooks_filter) + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_end_to_end_on_simulation_mode_small_samples( + independent_filter: bool, + cooks_filter: bool, + design_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Compare FL and pooled deseq2 pipelines. + + The data is TCGA-LUAD, restricted to a small number of samples. + + Parameters + ---------- + independent_filter: bool + If true, use the independent filtering step. If not, use standard + p-value adjustment. + cooks_filter: bool + If true, the Cook's filtering is applied at the end of the pipeline. + design_factors: str or list + The design factors to use. + raw_data_path : Path + The path to the root data. + tmp_processed_data_path : Path + The path to the processed data. The subdirectories will + be created if needed. + tcga_assets_directory : Path + The path to the assets directory. 
It must contain the + opener.py file and the description.md file. + + """ + pipeline_to_test( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + small_samples=True, + simulate=True, + cooks_filter=cooks_filter, + independent_filter=independent_filter, + only_two_centers=False, + design_factors=design_factors, + ) + + +@pytest.mark.parametrize( + "only_two_centers, independent_filter, cooks_filter, design_factors", + [ + (only_two_centers, independent_filter, cooks_filter, design_factors) + for ( + only_two_centers, + independent_filter, + cooks_filter, + design_factors, + ) in product([True, False], [True, False], COOKS_FILTER, DESIGN_FACTORS) + if ( + only_two_centers + and independent_filter + and cooks_filter + and design_factors == "stage" + ) + or ( + not only_two_centers + and (design_factors == "stage" or (cooks_filter and independent_filter)) + ) + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_end_to_end_on_simulation_mode_small_genes( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + only_two_centers, + independent_filter, + cooks_filter, + design_factors, +): + """Compare FL and pooled deseq2 pipelines. + + The data is TCGA-LUAD, restricted to a small number of genes. + + Parameters + ---------- + raw_data_path : Path + The path to the root data. + tmp_processed_data_path : Path + The path to the processed data. The subdirectories will + be created if needed. + tcga_assets_directory : Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + only_two_centers: bool + If true, restrict the data to two centers. + independent_filter: bool + If true, use the independent filtering step. If not, use standard + p-value adjustment. + cooks_filter: bool + If true, the Cook's filtering is applied at the end of the pipeline. + design_factors : str or list + The design factors to use. + """ + pipeline_to_test( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + small_genes=True, + simulate=True, + only_two_centers=only_two_centers, + independent_filter=independent_filter, + cooks_filter=cooks_filter, + refit_cooks=True, + design_factors=design_factors, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors, contrast", + [ + (["stage", "gender"], None, ["stage", "Advanced", "Non-advanced"]), + (["stage", "gender", "CPE"], ["CPE"], ["CPE", "", ""]), + (["stage", "gender", "CPE"], ["CPE"], ["gender", "female", "male"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_end_to_end_multifactor_on_simulation_mode_small_genes( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, + contrast, +): + """Compare FL and pooled deseq2 pipelines. + + The data is TCGA-LUAD, restricted to a small number of genes. + + Parameters + ---------- + raw_data_path : Path + The path to the root data. + tmp_processed_data_path : Path + The path to the processed data. The subdirectories will + be created if needed. + tcga_assets_directory : Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + design_factors: str or list[str] + The design factors to use. + continuous_factors: list[str] or None + The continuous factors to use. + contrast: list[str] or None + The contrast to use. 
+ """ + pipeline_to_test( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + small_genes=True, + simulate=True, + only_two_centers=True, + independent_filter=True, + cooks_filter=True, + refit_cooks=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + contrast=contrast, + ) + + +@pytest.mark.parametrize( + "cooks_filter, design_factors", list(product(COOKS_FILTER, DESIGN_FACTORS)) +) +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_end_to_end_on_simulation_mode_small_genes_refit_cooks_false( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + cooks_filter, + design_factors, +): + """Compare FL and pooled deseq2 pipelines. + + The data is TCGA-LUAD, restricted to a small number of genes. + + Parameters + ---------- + raw_data_path : Path + The path to the root data. + tmp_processed_data_path : Path + The path to the processed data. The subdirectories will + be created if needed. + tcga_assets_directory : Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + cooks_filter: bool + If true, the Cook's filtering is applied at the end of the pipeline. + design_factors : str or list + The design factors to use. + """ + pipeline_to_test( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + small_genes=True, + simulate=True, + only_two_centers=True, + independent_filter=True, + cooks_filter=cooks_filter, + refit_cooks=False, + design_factors=design_factors, + ) + + +@pytest.mark.docker +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_end_to_end_on_docker_mode_on_self_hosted_small_genes( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + # This is a workaround to avoid the docker daemon issue. + os.environ["DOCKER_HOST"] = "unix:///run/user/1000/docker.sock" + pipeline_to_test( + raw_data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + simulate=False, + backend="docker", + small_genes=True, + cooks_filter=True, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.parametrize( + "dataset_name", + ["TCGA-LUAD", "TCGA-PAAD"], +) +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_end_to_end_on_subprocess_mode_on_self_hosted_small_genes_small_samples( + dataset_name: TCGADatasetNames, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Compare FL and pooled deseq2 pipelines. + + The data is TCGA-LUAD on a self hosted runner. + + Parameters + ---------- + dataset_name: TCAGDatasetNames + The name of the dataset, for example "TCGA-LUAD". + raw_data_path : Path + The path to the root data. + tmp_processed_data_path : Path + The path to the processed data. The subdirectories will + be created if needed. + tcga_assets_directory : Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + """ + # Get the ground truth. 
+    pipeline_to_test(
+        raw_data_path=raw_data_path,
+        processed_data_path=tmp_processed_data_path,
+        assets_directory=tcga_assets_directory,
+        dataset_name=dataset_name,
+        simulate=False,
+        independent_filter=True,
+        cooks_filter=True,
+        backend="subprocess",
+        small_genes=True,
+        small_samples=True,
+        only_two_centers=True,
+    )
diff --git a/tests/deseq2_end_to_end/test_deseq2_pipe_local.py b/tests/deseq2_end_to_end/test_deseq2_pipe_local.py
new file mode 100644
index 0000000..32968ca
--- /dev/null
+++ b/tests/deseq2_end_to_end/test_deseq2_pipe_local.py
@@ -0,0 +1,351 @@
+"""End-to-end tests of the DESeq2Strategy pipeline meant to be run locally."""
+from itertools import product
+
+import pytest
+
+from .test_deseq2_pipe_utils import pipeline_to_test
+
+COOKS_FILTER = [True, False]
+
+
+@pytest.mark.local
+@pytest.mark.parametrize(
+    "only_two_centers, independent_filter, cooks_filter",
+    list(product([True], [True], [True])),
+)
+@pytest.mark.usefixtures(
+    "raw_data_path", "local_processed_data_path", "tcga_assets_directory"
+)
+def test_end_to_end_on_subprocess_mode_local_small_genes(
+    raw_data_path,
+    local_processed_data_path,
+    tcga_assets_directory,
+    only_two_centers,
+    independent_filter,
+    cooks_filter,
+):
+    """Compare FL and pooled deseq2 pipelines.
+
+    This test is here to be able to locally test the pipeline on non-simulated data.
+
+    The data is TCGA-LUAD.
+
+    Parameters
+    ----------
+    raw_data_path : Path
+        The path to the root data.
+    local_processed_data_path : Path
+        The path to the processed data. The subdirectories will
+        be created if needed.
+    tcga_assets_directory : Path
+        The path to the assets directory. It must contain the
+        opener.py file and the description.md file.
+    only_two_centers: bool
+        If true, restrict the data to two centers.
+    independent_filter: bool
+        If true, use the independent filtering step. If not, use standard
+        p-value adjustment.
+    cooks_filter: bool
+        If true, the Cook's filtering is applied at the end of the pipeline.
+
+    """
+    pipeline_to_test(
+        raw_data_path=raw_data_path,
+        processed_data_path=local_processed_data_path,
+        assets_directory=tcga_assets_directory,
+        dataset_name="TCGA-LUAD",
+        simulate=False,
+        small_genes=True,
+        cooks_filter=cooks_filter,
+        only_two_centers=only_two_centers,
+        independent_filter=independent_filter,
+    )
+
+
+@pytest.mark.local
+@pytest.mark.parametrize(
+    "independent_filter, cooks_filter",
+    product([True, False], [True, False]),
+)
+@pytest.mark.usefixtures(
+    "raw_data_path", "local_processed_data_path", "tcga_assets_directory"
+)
+def test_end_to_end_on_simu_mode_local(
+    independent_filter: bool,
+    cooks_filter: bool,
+    raw_data_path,
+    local_processed_data_path,
+    tcga_assets_directory,
+):
+    """Compare FL and pooled deseq2 pipelines.
+
+    This test is here to be able to locally test the pipeline in simulation mode.
+
+    The data is TCGA-LUAD.
+
+    Parameters
+    ----------
+    independent_filter: bool
+        If true, use the independent filtering step. If not, use standard
+        p-value adjustment.
+    cooks_filter: bool
+        If true, the Cook's filtering is applied at the end of the pipeline.
+    raw_data_path : Path
+        The path to the root data.
+    local_processed_data_path : Path
+        The path to the processed data. The subdirectories will
+        be created if needed.
+    tcga_assets_directory : Path
+        The path to the assets directory. It must contain the
+        opener.py file and the description.md file.
+ + """ + pipeline_to_test( + raw_data_path=raw_data_path, + processed_data_path=local_processed_data_path, + assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + simulate=True, + small_genes=False, + cooks_filter=cooks_filter, + independent_filter=independent_filter, + ) + + +@pytest.mark.local +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_end_to_end_on_subprocess_mode_local( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + """Compare FL and pooled deseq2 pipelines. + + This test is here to be able to locally test the pipeline on non simulated data. + + The data is TCGA-LUAD. + + Parameters + ---------- + raw_data_path : Path + The path to the root data. + local_processed_data_path : Path + The path to the processed data. The subdirectories will + be created if needed. + tcga_assets_directory : Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + """ + pipeline_to_test( + raw_data_path=raw_data_path, + processed_data_path=local_processed_data_path, + assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + simulate=False, + small_genes=False, + small_samples=False, + clean_models=False, + only_two_centers=True, + ) + + +@pytest.mark.local +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_end_to_end_on_subprocess_mode_local_on_self_hosted_keep_models( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Compare FL and pooled deseq2 pipelines. + + This test is here to be able to locally test the pipeline on non simulated data. + + The data is TCGA-LUAD. + + Parameters + ---------- + raw_data_path : Path + The path to the root data. + tmp_processed_data_path : Path + The path to the processed data. The subdirectories will + be created if needed. + tcga_assets_directory : Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + """ + pipeline_to_test( + raw_data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + simulate=False, + small_genes=False, + small_samples=False, + clean_models=False, + only_two_centers=True, + ) + + +@pytest.mark.local +@pytest.mark.parametrize( + "design_factors, continuous_factors, contrast", + [ + (["stage", "gender"], None, ["stage", "Advanced", "Non-advanced"]), + (["stage", "gender", "CPE"], ["CPE"], ["gender", "female", "male"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_end_to_end_multifactor_on_subprocess_mode_local_on_self_hosted_keep_models( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, + contrast, +): + """Compare FL and pooled deseq2 pipelines. + + This test is here to be able to locally test the pipeline on non simulated data. + The keep_models flag is set to True, in order to evaluate the memory usage. + + The data is TCGA-LUAD. + + Parameters + ---------- + raw_data_path : Path + The path to the root data. + tmp_processed_data_path : Path + The path to the processed data. The subdirectories will + be created if needed. + tcga_assets_directory : Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + design_factors: str or list[str] + The design factors to use. 
+ continuous_factors: list[str] or None + The continuous factors to use. + contrast: list[str] or None + The contrast to use. + + """ + pipeline_to_test( + raw_data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + simulate=False, + small_genes=False, + small_samples=False, + clean_models=False, + only_two_centers=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + contrast=contrast, + ) + + +@pytest.mark.local +@pytest.mark.parametrize( + "independent_filter, cooks_filter", + product([True, False], COOKS_FILTER), +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_end_to_end_on_simulation_mode_local_small_genes( + independent_filter: bool, + cooks_filter: bool, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Compare FL and pooled deseq2 pipelines. + + The data is TCGA-LUAD, restricted to a small number of genes. + + Parameters + ---------- + independent_filter: bool + If true, use the independent filtering step. If not, use standard + p-value adjustment. + cooks_filter: bool + If true, the Cook's filtering is applied at the end of the pipeline. + raw_data_path : Path + The path to the root data. + tmp_processed_data_path : Path + The path to the processed data. The subdirectories will + be created if needed. + tcga_assets_directory : Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + """ + pipeline_to_test( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + small_genes=True, + small_samples=False, + cooks_filter=cooks_filter, + simulate=True, + only_two_centers=False, + independent_filter=independent_filter, + ) + + +@pytest.mark.local +@pytest.mark.parametrize( + "independent_filter, cooks_filter", + product( + [True, False], + COOKS_FILTER, + ), +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_end_to_end_on_simulation_mode_local_small_samples( + independent_filter: bool, + cooks_filter: bool, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Compare FL and pooled deseq2 pipelines. + + The data is TCGA-LUAD, restricted to a small number of genes. + + Parameters + ---------- + independent_filter: bool + If true, use the independent filtering step. If not, use standard + p-value adjustment. + cooks_filter: bool + If true, the Cook's filtering is applied at the end of the pipeline. + raw_data_path : Path + The path to the root data. + tmp_processed_data_path : Path + The path to the processed data. The subdirectories will + be created if needed. + tcga_assets_directory : Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. 
+ + """ + pipeline_to_test( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + small_genes=False, + small_samples=True, + simulate=True, + only_two_centers=False, + independent_filter=independent_filter, + ) diff --git a/tests/deseq2_end_to_end/test_deseq2_pipe_utils.py b/tests/deseq2_end_to_end/test_deseq2_pipe_utils.py new file mode 100644 index 0000000..ab286e0 --- /dev/null +++ b/tests/deseq2_end_to_end/test_deseq2_pipe_utils.py @@ -0,0 +1,109 @@ +"""Main utilities to test the deseq2 pipeline.""" + +from pathlib import Path +from typing import Literal + +from fedpydeseq2 import DESeq2Strategy +from tests.tcga_testing_pipe import run_tcga_testing_pipe + + +def pipeline_to_test( + raw_data_path: Path, + processed_data_path: Path, + assets_directory: Path, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + independent_filter: bool = True, + cooks_filter: bool = True, + refit_cooks: bool = True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + continuous_factors: list[str] | None = None, + contrast: list[str] | None = None, + alt_hypothesis: Literal["greaterAbs", "lessAbs", "greater", "less"] | None = None, + lfc_null: float = 0.0, + reference_dds_ref_level: tuple[str, ...] | None = ("stage", "Advanced"), + clean_models: bool = True, +): + """Compare FL and pooled deseq2 pipelines. + + Parameters + ---------- + raw_data_path: Path + The path to the root data . + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + independent_filter: bool + If true, use the independent filtering step. If not, use standard + p-value adjustment. + cooks_filter: bool + If true, the Cook's filtering is applied at the end of the pipeline. + refit_cooks: bool + If true, refit the Cook's filtering. + backend: str + The backend to use. Either "subprocess" or "docker". + only_two_centers: bool + If true, restrict the data to two centers. + design_factors: str or list + The design factors to use. + ref_levels: dict or None + The reference levels of the design factors. + continuous_factors: list or None + The continuous factors to use. + contrast: list or None + The contrast to use. + alt_hypothesis: Literal["greaterAbs", "lessAbs", "greater", "less"] or None + The alternative hypothesis to use. + lfc_null: float + The null log fold change. + reference_dds_ref_level: tuple or None + The reference level of the design factors. + clean_models: bool + Whether to clean the models after the computation. 
+ """ + # Run the tcga experiment to check convergence + run_tcga_testing_pipe( + DESeq2Strategy( + design_factors=design_factors, + ref_levels=ref_levels, + independent_filter=independent_filter, + cooks_filter=cooks_filter, + refit_cooks=refit_cooks, + continuous_factors=continuous_factors, + contrast=contrast, + alt_hypothesis=alt_hypothesis, + lfc_null=lfc_null, + ), + raw_data_path=raw_data_path, + processed_data_path=processed_data_path, + assets_directory=assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + clean_models=clean_models, + ) diff --git a/tests/paths_default.json b/tests/paths_default.json new file mode 100644 index 0000000..c5312a6 --- /dev/null +++ b/tests/paths_default.json @@ -0,0 +1 @@ +{"raw_data": "../data/raw", "processed_data": "../data/processed"} diff --git a/tests/tcga_testing_pipe.py b/tests/tcga_testing_pipe.py new file mode 100644 index 0000000..ceb3681 --- /dev/null +++ b/tests/tcga_testing_pipe.py @@ -0,0 +1,181 @@ +"""Implements a function running a substrafl experiment with tcga dataset.""" + +from datetime import datetime +from pathlib import Path + +import pandas as pd +from fedpydeseq2_datasets.constants import TCGADatasetNames +from fedpydeseq2_datasets.create_reference_dds import setup_tcga_ground_truth_dds +from fedpydeseq2_datasets.process_and_split_data import setup_tcga_dataset +from fedpydeseq2_datasets.utils import get_experiment_id +from substra.sdk.schemas import BackendType +from substrafl import ComputePlanBuilder + +from fedpydeseq2.substra_utils.federated_experiment import run_federated_experiment + + +def run_tcga_testing_pipe( + strategy: ComputePlanBuilder, + raw_data_path: Path, + processed_data_path: Path, + assets_directory: Path, + backend: BackendType = "subprocess", + simulate: bool = True, + dataset_name: TCGADatasetNames = "TCGA-LUAD", + small_samples: bool = False, + small_genes: bool = False, + only_two_centers: bool = False, + register_data: bool = False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + reference_dds_ref_level: tuple[str, ...] | None = ("stage", "Advanced"), + refit_cooks: bool = False, + remote_timeout: int = 86400, # 24 hours + clean_models: bool = True, +) -> dict: + """Runa tcga experiment using the given substrafl strategy. + + The raw_data_path is expected to have the following structure: + ``` + + ├── tcga + │ ├── _clinical.tsv.gz + │ └── _raw_RNAseq.tsv.gz + └── + ``` + + The processed_data_path will be created if it does not exist and + will contain the following structure: + + ``` + + ├── centers_data + │ └── tcga + │ └── + │ └── center_0 + │ ├── counts_data.csv + │ ├── metadata.csv + │ └── ground_truth_dds.pkl + ├── pooled_data + │ └── tcga + │ └── + │ ├── counts_data.csv + │ ├── metadata.csv + │ └── ground_truth_dds.pkl + └── + + ``` + + Parameters + ---------- + strategy : ComputePlanBuilder + + raw_data_path : Path + The path to the raw tcga data. must contain a folder "tcga" with the structure + described above. + processed_data_path : Path + The path to the processed data. The subfolders will be created if does not + exist. + assets_directory : Path + The path to the assets directory. Is expected to contain the opener.py and + description.md files. 
+    backend : str, optional
+        'docker', or 'subprocess'. (Default='subprocess').
+    simulate : bool, optional
+        If True, the experiment is simulated. (Default=True).
+    dataset_name : TCGADatasetNames, optional
+        The dataset to preprocess, by default "TCGA-LUAD".
+    small_samples : bool, optional
+        If True, only preprocess a small subset of the data, by default False.
+        The number of samples is reduced to 10 per center.
+    small_genes : bool, optional
+        If True, only preprocess a small subset of the genes, by default False.
+        The number of genes is reduced to 100.
+    only_two_centers : bool, optional
+        If True, merge the centers into two centers, by default False.
+    register_data : bool, optional
+        If True, register the data to substra. Otherwise, use pre-registered dataset
+        and data samples. By default False.
+    design_factors : Union[str, list[str]]
+        Name of the columns of metadata to be used as design variables.
+        For now, only "stage", "gender" and "CPE" are supported.
+    continuous_factors : list[str] or None
+        The factors which are continuous.
+    reference_dds_ref_level : tuple[str, ...] or None
+        The reference levels for the design factors. If None, the first level is used.
+    refit_cooks : bool
+        If True, refit the model after removing the Cook's distance outliers.
+    remote_timeout : int
+        The timeout for the remote experiment in seconds.
+        This means that we wait for at most `remote_timeout` seconds for the experiment
+        to finish. If the experiment is not finished after this time, we raise an error.
+        The default is 86400 s (24h).
+    clean_models : bool
+        If True, clean the models after the experiment. (Default=True).
+
+
+    Returns
+    -------
+    dict
+        Results of the strategy, which are assumed to be contained in the
+        results attribute of the last round of the aggregation node.
+ """ + current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + compute_plan_name = f"FedPyDESeq2_{dataset_name}_{current_datetime}" + + print("Setting up TCGA dataset...") + setup_tcga_dataset( + raw_data_path, + processed_data_path, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + print("Setting up TCGA ground truth DESeq2 datasets...") + setup_tcga_ground_truth_dds( + processed_data_path, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + default_refit_cooks=refit_cooks, + ) + + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + metadata = pd.read_csv( + processed_data_path / "pooled_data" / "tcga" / experiment_id / "metadata.csv" + ) + n_centers = len(metadata.center_id.unique()) + + fl_results = run_federated_experiment( + strategy=strategy, + register_data=register_data, + n_centers=n_centers, + backend=backend, + simulate=simulate, + centers_root_directory=processed_data_path + / "centers_data" + / "tcga" + / experiment_id, + assets_directory=assets_directory, + compute_plan_name=compute_plan_name, + dataset_name=dataset_name, + remote_timeout=remote_timeout, + clean_models=clean_models, + save_filepath=None, + ) + + return fl_results diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/deseq2_core/__init__.py b/tests/unit_tests/deseq2_core/__init__.py new file mode 100644 index 0000000..ce03f44 --- /dev/null +++ b/tests/unit_tests/deseq2_core/__init__.py @@ -0,0 +1 @@ +"""Module to test the DESeq2 pipeline.""" diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/__init__.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/__init__.py new file mode 100644 index 0000000..09da94f --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/__init__.py @@ -0,0 +1 @@ +"""Module to test the deseq2_lfc_dispersions module.""" diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/__init__.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/__init__.py new file mode 100644 index 0000000..c950fbe --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/__init__.py @@ -0,0 +1,15 @@ +"""Module to test the different components for computing genewise dispersions. + +In the single factor case, it is possible to perform a test on the whole block. + +However, for the multifactor case, we must decompose the test into four smaller +steps that are individually unit tested, because they are sensitive to the propagation +of intermediate results. + +- The first step is to compute the MoM dispersions, which is theoretically independent +of the number of design factors. +- The second step is to compute the mu_hat estimate. +- The third step is to compute the number of replicates. +- The fourth step is to compute the dispersion from mu_hat and the number of replicates. 
+ +""" diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_MoM_dispersions.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_MoM_dispersions.py new file mode 100644 index 0000000..e8a924a --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_MoM_dispersions.py @@ -0,0 +1,454 @@ +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +import numpy as np +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.compute_MoM_dispersions import ( # noqa: E501 + ComputeMoMDispersions, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.pass_on_results import AggPassOnResults +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_MoM_dispersions_on_small_genes_small_samples( + design_factors, + continuous_factors, + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + MoM_dispersions_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_MoM_dispersions_on_small_samples_on_self_hosted_fast( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + MoM_dispersions_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_MoM_dispersions_on_self_hosted_slow( + design_factors, + 
continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + MoM_dispersions_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def MoM_dispersions_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, +): + """Perform a unit test for the MoM dispersions. + + Starting with the same size factors as the reference dataset, compute genewise + dispersions and compare the results with the reference. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + """ + complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = ( + local_processed_data_path / "centers_data" / "tcga" / experiment_id + ) + # Get FL results. 
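+    # The tester's compute plan ends with an aggregation step that passes the
+    # MoM dispersions shared state on as the strategy results, so they can be
+    # retrieved from fl_results and compared to the pooled reference below.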
+ fl_results = run_tcga_testing_pipe( + MoMDispersionsTester( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=complete_ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=raw_data_path, + processed_data_path=local_processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + local_processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # Tests that the MoM dispersions are close + + assert np.allclose( + fl_results["MoM_dispersions"], + pooled_dds.varm["_MoM_dispersions"], + equal_nan=True, + ) + + +class MoMDispersionsTester(UnitTester, ComputeMoMDispersions, AggPassOnResults): + """A class to implement a unit test for the genewise dispersions. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + min_disp : float + Lower threshold for dispersion parameters. (default: ``1e-8``). + + max_disp : float + Upper threshold for dispersion parameters. + Note: The threshold that is actually enforced is max(max_disp, len(counts)). + (default: ``10``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + Methods + ------- + build_compute_plan + Build the computation graph to collect the MoM dispersions results. + + init_local_states + Initialize the local states. This method creates the local adata and + computes the local gram matrix and local features, which are necessary inputs + to compute MoM dispersions. + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + min_disp: float = 1e-8, + max_disp: float = 10.0, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. 
+ super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + min_disp=min_disp, + max_disp=max_disp, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ) + + #### Define hyper parameters #### + + self.min_disp = min_disp + self.max_disp = max_disp + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a FedDESeq2 pipe. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + ( + local_states, + mom_dispersions_shared_state, + round_idx, + ) = self.compute_MoM_dispersions( + train_data_nodes, + aggregation_node, + local_states, + gram_features_shared_states=shared_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=[mom_dispersions_shared_state], + description="Save MoM dispersions", + round_idx=round_idx, + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Initialize the local states. + + Here, we copy the reference dds to the local state, and + create the local gram matrix and local features, which are necessary inputs + to the genewise dispersions computation. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. + shared_state : Any + Shared state. Not used. + + Returns + ------- + dict + Local states containing a "local_gran_matrix" and a "local_features" fields. + These fields are used to compute the rough dispersions, and are computed in + the last step of the compute_size_factors step in the main pipe. 
+ """ + + self.local_adata = self.local_reference_dds.copy() + # Delete the unused "_mu_hat" layer as it will raise errors + del self.local_adata.layers["_mu_hat"] + # This field is not saved in pydeseq2 but used in fedpyseq2 + self.local_adata.uns["n_params"] = self.local_adata.obsm["design_matrix"].shape[ + 1 + ] + design = self.local_adata.obsm["design_matrix"].values + + return { + "local_gram_matrix": design.T @ design, + "local_features": design.T @ self.local_adata.layers["normed_counts"], + } diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_compute_mu_hat.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_compute_mu_hat.py new file mode 100644 index 0000000..3135f78 --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_compute_mu_hat.py @@ -0,0 +1,310 @@ +"""Unit tests for the computation of mu hat. + +Here, we test the computation of mu hat using IRLS, in the +compute_genewise_dispersions step of the DESeq2 algorithm. + +Note that mu hat can be computed in two ways in the compute_genewise_dispersions step. +- if num_vars == num_levels, then mu_hat is computed as the solution to a linear system +- otherwise, mu_hat is computed using IRLS + +In this file, we only test THE SECOND CASE, as the algorithm we use is not +the same as the one used in the pooled setting. + +""" + +import pytest + +from tests.unit_tests.deseq2_core.deseq2_lfc_dispersions.compute_lfc.compute_lfc_test_pipe import ( # noqa: E501 + pipe_test_compute_lfc, +) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_mu_hat_small_genes( + design_factors, + continuous_factors, + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + """Perform a unit test to see if mu_hat is working as expected. + + This test focuses on a small number of genes, to see if the algorithm is working + as expected in a fast way. + + Recall that computing mu hat is the second step in the compute_genewise_dispersions + step, after computing the MoM dispersions and before computing the dispersions + estimates. + + Parameters + ---------- + design_factors: Union[str, List[str]] + The design factors. + + continuous_factors: Union[str, List[str]] + The continuous factors. + + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. 
+ + """ + + pipe_test_compute_lfc( + lfc_mode="mu_init", + data_path=raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + PQN_min_mu=0.0, + nll_rtol=0.02, + tolerated_failed_genes=1, + ) + + +@pytest.mark.self_hosted_fast +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + # (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_mu_hat_small_samples( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Perform a unit test to see if computing mu hat is working as expected. + + This test focuses on a small number of samples, to see if the algorithm is working + as expected in the self hosted CI. Note that on a small number of samples, the + algorithm is less performant then when there are more samples (see the + test_mu_hat test). This can be explained by the fact that the log likelihood + is somehow less smooth when there are few data points. + + Note that for a reason that is not clear, for IRLS converged genes, a tolerance + of 1e-5 is too hard (even if theoretically, the algorithm is pooled equivalent). + However, the results are still quite close to the pooled ones, and we do not + investigate this further for now. + + Parameters + ---------- + design_factors: Union[str, List[str]] + The design factors. + + continuous_factors: Union[str, List[str]] + The continuous factors. + + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + """ + + pipe_test_compute_lfc( + lfc_mode="mu_init", + data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + PQN_min_mu=0.0, + nll_rtol=0.02, + tolerated_failed_genes=15, + rtol_irls=1e-3, + atol_irls=1e-5, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_mu_hat_luad( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Perform a unit test to see if computing mu hat is working as expected. + + This test focuses on a large number of samples and genes, to see if the algorithm + is working as expected in the self hosted CI. + + Note that a relative tolerance of 1e-3 is used, instead of the default 1e-5. 
The + reasons for which this is needed are not clear, but the results are still quite + close to the pooled ones, and we do not investigate this further for now. + + Parameters + ---------- + design_factors: Union[str, List[str]] + The design factors. + + continuous_factors: Union[str, List[str]] + The continuous factors. + + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + """ + + pipe_test_compute_lfc( + lfc_mode="mu_init", + data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + PQN_min_mu=0.0, + nll_rtol=0.02, + tolerated_failed_genes=5, + rtol_irls=1e-3, + atol_irls=1e-5, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_mu_hat_paad( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Perform a unit test to see if computing mu hat is working as expected. + + This test focuses on a large number of samples and genes, to see if the algorithm + is working as expected in the self hosted CI. + + Note that we do not add CPE as a continuous factor, as it is not present in the + PAAD dataset. + + Note that a relative tolerance of 1e-2 is used, instead of the default 1e-5. The + reasons for which this is needed are not clear, but the results are still quite + close to the pooled ones, and we do not investigate this further for now. + + Parameters + ---------- + design_factors: Union[str, List[str]] + The design factors. + + continuous_factors: Union[str, List[str]] + The continuous factors. + + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. 
+ + """ + + pipe_test_compute_lfc( + lfc_mode="mu_init", + data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-PAAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + PQN_min_mu=0.0, + nll_rtol=0.02, + tolerated_failed_genes=5, + rtol_irls=1e-2, + atol_irls=1e-2, + ) diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_dispersions_from_mu_hat.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_dispersions_from_mu_hat.py new file mode 100644 index 0000000..fc81752 --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_dispersions_from_mu_hat.py @@ -0,0 +1,861 @@ +"""Module to test the final step of the compute_genewise_dispersions step. + +This step tests the final substep which is to estimate the genewise dispersions +by minimizing the negative binomial likelihood, with a fixed value of +the mean parameter given by the mu_hat estimate. + +""" +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +import numpy as np +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.fed_algorithms import ComputeDispersionsGridSearch +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.utils_genewise_dispersions import ( # noqa: E501 + perform_dispersions_and_nll_relative_check, +) +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +def test_dispersions_from_mu_hat_on_small_genes_small_samples( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, +): + """Perform a unit test for the genewise dispersions. + + This test is performed on a small number of genes and samples, in order to + be fast and run on the github CI. + + Note that in this test, we tolerate 0 failed genes. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. The subdirectories will + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. 
+ + """ + dispersions_from_mu_hat_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + design_factors=design_factors, + continuous_factors=continuous_factors, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "design_factors, continuous_factors, tolerated_failed_genes", + [ + ("stage", None, 0), + (["stage", "gender"], None, 0), + (["stage", "gender", "CPE"], ["CPE"], 1), + ], +) +def test_dispersions_from_mu_hat_on_small_samples_on_self_hosted_fast( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, + tolerated_failed_genes, +): + """Perform a unit test for the genewise dispersions. + + This test is performed on a small number of samples, in order to be fast. + However, all genes are used, so that we can have a clearer statistical vision + of the failing cases (50 000 genes). + + In only one case (the last one), we authorized one failed gene. We have not + investigated further why it fails. Our guess is that the underlying assumption + for the grid search we perform (with multiple steps) is that the nll + decreases then increases. This is perhaps not always true (we have no theoretical + guarantee), or perhaps the nll is very sharp. In any case, it is not suprising + that such things happen in a case where we have very few samples (samples smooth + the losses). + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. The subdirectories will + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. + + tolerated_failed_genes: int + The number of genes that are allowed to fail the relative nll criterion. + + """ + dispersions_from_mu_hat_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + tolerated_failed_genes=tolerated_failed_genes, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +def test_dispersions_from_mu_hat_on_self_hosted_slow( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, +): + """Perform a unit test for the genewise dispersions. + + This test is performed on the full dataset, in order to have a more + realistic view of the performance of the algorithm, on a self hosted + runner and using the TCGA-LUAD dataset. + + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. 
The subdirectories will + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. + + """ + dispersions_from_mu_hat_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + ], +) +def test_dispersions_from_mu_hat_paad_on_self_hosted_slow( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, +): + """Perform a unit test for the genewise dispersions. + + This test is performed on the full dataset, in order to have a more + realistic view of the performance of the algorithm, on a self hosted + runner and using the TCGA-PAAD dataset. + + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. The subdirectories will + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. + + """ + dispersions_from_mu_hat_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-PAAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def dispersions_from_mu_hat_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, + rtol: float = 0.02, + atol: float = 1e-3, + nll_rtol: float = 0.02, + nll_atol: float = 1e-3, + tolerated_failed_genes: int = 0, +): + """Perform a unit test for the genewise dispersions. + + This unit test only concerns the last step of the genewise dispersions fitting, + that is fitting the dispersions from the mu hat estimate by minimizing the + negative binomial likelihood function (as a function of the dispersion only). + + We start from all quantities defined in the pooled data. We then fit the genewise + dispersions using the mu hat estimate as the mean parameter. We then compare the + FL dispersions to the pooled dispersions. If the relative error is above 2%, + we check the likelihoods. If the likelihoods are above 2% higher than the pooled + likelihoods, we fail the test (with a certain tolerance for failure). + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. 
The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + + rtol: float + The relative tolerance for between the FL and pooled dispersions. + + atol: float + The absolute tolerance for between the FL and pooled dispersions. + + nll_rtol: float + The relative tolerance for between the FL and pooled likelihoods, in the + case where the dispersions are above the tolerance. + + nll_atol: float + The absolute tolerance for between the FL and pooled likelihoods, in the + case where the dispersions are above the tolerance. + + tolerated_failed_genes: int + The number of genes that are allowed to fail the relative nll criterion. + """ + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = ( + local_processed_data_path / "centers_data" / "tcga" / experiment_id + ) + # Get FL results. + fl_results = run_tcga_testing_pipe( + GenewiseDispersionsFromMuHatTester( + design_factors=design_factors, + ref_levels=ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + continuous_factors=continuous_factors, + ), + raw_data_path=raw_data_path, + processed_data_path=local_processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + local_processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # Compute relative error + + perform_dispersions_and_nll_relative_check( + fl_results["genewise_dispersions"], + pooled_dds, + rtol=rtol, + atol=atol, + nll_rtol=nll_rtol, + nll_atol=nll_atol, + tolerated_failed_genes=tolerated_failed_genes, + ) + + +class GenewiseDispersionsFromMuHatTester(UnitTester, ComputeDispersionsGridSearch): + """A class to implement a unit test for the genewise dispersions fitting. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. 
+ If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + min_disp : float + Lower threshold for dispersion parameters. (default: ``1e-8``). + + max_disp : float + Upper threshold for dispersion parameters. + Note: The threshold that is actually enforced is max(max_disp, len(counts)). + (default: ``10``). + + grid_batch_size : int + The number of genes to put in each batch for local parallel processing. + (default: ``100``). + + grid_depth : int + The number of grid interval selections to perform (if using GridSearch). + (default: ``3``). + + grid_length : int + The number of grid points to use for the grid search (if using GridSearch). + (default: ``100``). + + num_jobs : int + The number of jobs to use for local parallel processing in MLE tasks. + (default: ``8``). + + joblib_verbosity : int + Verbosity level for joblib. (default: ``3``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + Methods + ------- + build_compute_plan + Build the computation graph to run and save the genewise dispersions. + + save_genewise_dispersions_checkpoint + Save genewise dispersions checkpoint. + + init_local_states + A local method. + Copy the reference dds to the local state and add the non_zero mask. + + sum_num_samples + An aggregation method. + Compute the total number of samples to set max_disp. + + set_max_disp + A local method. + Set max_disp using the total number of samples in the study. + + get_local_dispersions + A local method. + Collect dispersions and pass on. + + pass_on_results + An aggregation method. + Set the genewise dispersions in the results. + + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + min_disp: float = 1e-8, + max_disp: float = 10.0, + grid_batch_size: int = 250, + grid_depth: int = 3, + grid_length: int = 100, + num_jobs=8, + joblib_backend: str = "loky", + joblib_verbosity: int = 3, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. 
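+        # The same values are also stored as instance attributes below, since the
+        # remote substeps (e.g. ``set_max_disp`` and the grid search steps) read
+        # them directly from ``self``.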
+ super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + min_disp=min_disp, + max_disp=max_disp, + grid_batch_size=grid_batch_size, + grid_depth=grid_depth, + grid_length=grid_length, + num_jobs=num_jobs, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + joblib_backend=joblib_backend, + joblib_verbosity=joblib_verbosity, + ) + + #### Define hyper parameters #### + + self.min_disp = min_disp + self.max_disp = max_disp + self.grid_batch_size = grid_batch_size + self.grid_depth = grid_depth + self.grid_length = grid_length + self.num_jobs = num_jobs + + #### Define job parallelization parameters #### + self.joblib_verbosity = joblib_verbosity + self.num_jobs = num_jobs + self.joblib_backend = joblib_backend + + # Very important, we need to keep these layers as they cannot be recomputed. + self.layers_to_save_on_disk = { + "local_adata": ["_mu_hat"], + "refit_adata": None, + } + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run and save the genewise dispersions. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, init_shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + ### Aggregation step to compute the total number of samples ### + + shared_state, round_idx = aggregation_step( + aggregation_method=self.sum_num_samples, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=init_shared_states, + description="Get the total number of samples.", + round_idx=round_idx, + clean_models=clean_models, + ) + + ### Set max_disp using the total number of samples ### + + local_states, _, round_idx = local_step( + local_method=self.set_max_disp, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=shared_state, + aggregation_id=aggregation_node.organization_id, + description="Compute max_disp", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Fit genewise dispersions #### + ( + local_states, + genewise_dispersions_shared_state, + round_idx, + ) = self.fit_dispersions( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + shared_state=None, + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Save genewise dispersions #### + + 
self.save_genewise_dispersions_checkpoint( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + genewise_dispersions_shared_state=genewise_dispersions_shared_state, + round_idx=round_idx, + clean_models=False, + ) + + def save_genewise_dispersions_checkpoint( + self, + train_data_nodes, + aggregation_node, + local_states, + genewise_dispersions_shared_state, + round_idx, + clean_models, + ): + """Save genewise dispersions checkpoint. + + This method saves the genewise dispersions. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + genewise_dispersions_shared_state: dict + Contains the output shared state of "fit_genewise_dispersions" step, + which contains a "genewise_dispersions" field used in this test. + + round_idx: int + The current round. + + clean_models: bool + Whether to clean the models after the computation. + + """ + # ---- Get local estimates ---- # + + local_states, shared_states, round_idx = local_step( + local_method=self.get_local_dispersions, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=genewise_dispersions_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Get local dispersions", + round_idx=round_idx, + clean_models=clean_models, + ) + + # ---- Save dispersions in result ---- # + + results_shared_state, round_idx = aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Save genewise dispersions in results", + round_idx=round_idx, + clean_models=clean_models, + ) + + return local_states, round_idx + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Copy the reference dds to the local state and add the non_zero mask. + + Returns a dictionary with the number of samples in the "num_samples" field, + in order to compute the max dispersion. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Not used. + + Returns + ------- + dict + The number of samples in the "num_samples" field. + """ + + self.local_adata = self.local_reference_dds.copy() + self.local_adata.varm["non_zero"] = self.local_reference_dds.varm["non_zero"] + + return { + "num_samples": self.local_adata.n_obs, + } + + @remote + @log_remote + def sum_num_samples(self, shared_states): + """Compute the total number of samples to set max_disp. + + Parameters + ---------- + shared_states : list + List of initial shared states copied from the reference adata. + + Returns + ------- + dict + The total number of samples in the "tot_num_samples" field. + + """ + tot_num_samples = np.sum([state["num_samples"] for state in shared_states]) + return {"tot_num_samples": tot_num_samples} + + @remote_data + @log_remote_data + @reconstruct_adatas + def set_max_disp(self, data_from_opener: ad.AnnData, shared_state: Any) -> dict: + """Set max_disp using the total number of samples in the study. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Not used. 
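+
+        Notes
+        -----
+        The bound that is actually stored is ``max(self.max_disp, tot_num_samples)``,
+        consistent with the convention that the enforced threshold is
+        ``max(max_disp, len(counts))``.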
+ """ + + self.local_adata.uns["max_disp"] = max( + self.max_disp, shared_state["tot_num_samples"] + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_local_dispersions( + self, data_from_opener: ad.AnnData, shared_state: dict + ) -> dict: + """Collect dispersions and pass on. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. + shared_state : dict + Shared state with the gene-wise dispersions. + + Returns + ------- + dict + Shared state the gene-wise dispersions. + """ + + return {"genewise_dispersions": shared_state["genewise_dispersions"]} + + @remote + @log_remote + def pass_on_results(self, shared_states: list): + """Set the genewise dispersions in the results. + + Parameters + ---------- + shared_states : list + List of shared states. The first element contains the genewise dispersions. + + """ + self.results = { + "genewise_dispersions": shared_states[0]["genewise_dispersions"], + } diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_genewise_dispersions_single_factor.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_genewise_dispersions_single_factor.py new file mode 100644 index 0000000..d936798 --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_genewise_dispersions_single_factor.py @@ -0,0 +1,748 @@ +"""Module to test the genewise dispersions computation in a single factor case.""" + +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +import numpy as np +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions import ( # noqa: E501 + ComputeGenewiseDispersions, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.utils_genewise_dispersions import ( # noqa: E501 + perform_dispersions_and_nll_relative_check, +) +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "design_factors", + [ + "stage", + ], +) +def test_genewise_dispersions_on_small_genes_small_samples( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + design_factors, +): + genewise_dispersions_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + design_factors=design_factors, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", 
"tcga_assets_directory" +) +@pytest.mark.parametrize( + "design_factors", + [ + "stage", + ], +) +def test_genewise_dispersions_on_small_samples_on_self_hosted_fast( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + design_factors, +): + genewise_dispersions_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize("design_factors", ["stage"]) +def test_genewise_dispersions_on_self_hosted_slow( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + design_factors, +): + genewise_dispersions_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def genewise_dispersions_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + design_factors, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, + rtol=0.02, + atol=1e-3, + nll_rtol=0.02, + nll_atol=1e-3, + tolerated_failed_genes=0, +): + """Perform a unit test for the genewise dispersions. + + This steps tests the three quantities generated by the genewise dispersions fitting: + - the genewise dispersions themselves which are the endpoint; + - the method of moments dispersions (MoM) which are the starting point; + - the initial estimate of mu, the mean of the counts, which is an intermediate step. + + + Starting with the same size factors as the reference dataset, compute genewise + dispersions and compare the results with the reference. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + + rtol: float + The relative tolerance for between the FL and pooled dispersions. 
+ + atol: float + The absolute tolerance for between the FL and pooled dispersions. + + nll_rtol: float + The relative tolerance for between the FL and pooled likelihoods, in the + + nll_atol: float + The absolute tolerance for between the FL and pooled likelihoods, in the + + tolerated_failed_genes: int + The number of genes that are allowed to fail the relative nll criterion. + + """ + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=None, + ) + + reference_data_path = ( + local_processed_data_path / "centers_data" / "tcga" / experiment_id + ) + # Get FL results. + fl_results = run_tcga_testing_pipe( + GenewiseDispersionsTester( + design_factors=design_factors, + ref_levels=ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=raw_data_path, + processed_data_path=local_processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + local_processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # Check that the MoM dispersions are close to the pooled ones + + assert np.allclose( + fl_results["MoM_dispersions"], + pooled_dds.varm["_MoM_dispersions"], + equal_nan=True, + ) + + # get sample ids for the pooled dds + sample_ids = pooled_dds.obs_names + sample_ids_fl = fl_results["sample_ids"] + # Get the permutation that sorts the sample ids + perm = np.argsort(sample_ids) + perm_fl = np.argsort(sample_ids_fl) + + fl_mu_hat = fl_results["mu_hat"][perm_fl] + pooled_mu_hat = pooled_dds.layers["_mu_hat"][perm] + + # Check that the nans are at the same places + assert np.all(np.isnan(fl_mu_hat) == np.isnan(pooled_mu_hat)) + + # Replace nans with 1. + fl_mu_hat[np.isnan(fl_mu_hat)] = 1.0 + pooled_mu_hat[np.isnan(pooled_mu_hat)] = 1.0 + + assert np.allclose( + np.sort(fl_results["mu_hat"].flatten()), + np.sort(pooled_dds.layers["_mu_hat"].flatten()), + equal_nan=True, + ) + + # Tests that the genewise dispersions are close to the pooled ones, or if not, + # that the adjusted log likelihood is close or better + + # Compute relative error + + # If any of the relative errors is above 2%, check likelihoods + perform_dispersions_and_nll_relative_check( + fl_results["genewise_dispersions"], + pooled_dds, + dispersions_param_name="genewise_dispersions", + rtol=rtol, + atol=atol, + nll_rtol=nll_rtol, + nll_atol=nll_atol, + tolerated_failed_genes=tolerated_failed_genes, + ) + + +class GenewiseDispersionsTester( + UnitTester, + ComputeGenewiseDispersions, +): + """A class to implement a unit test for the genewise dispersions. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. 
+ + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + min_disp : float + Lower threshold for dispersion parameters. (default: ``1e-8``). + + max_disp : float + Upper threshold for dispersion parameters. + Note: The threshold that is actually enforced is max(max_disp, len(counts)). + (default: ``10``). + + grid_batch_size : int + The number of genes to put in each batch for local parallel processing. + (default: ``100``). + + grid_depth : int + The number of grid interval selections to perform (if using GridSearch). + (default: ``3``). + + grid_length : int + The number of grid points to use for the grid search (if using GridSearch). + (default: ``100``). + + num_jobs : int + The number of jobs to use for local parallel processing in MLE tasks. + (default: ``8``). + + min_mu : float + Lower threshold for mean expression values. (default: ``0.5``). + + beta_tol : float + Tolerance for the beta coefficients. (default: ``1e-8``). + + max_beta : float + Upper threshold for the beta coefficients. (default: ``30``). + + irls_num_iter : int + Number of iterations for the IRLS algorithm. (default: ``20``). + + joblib_backend : str + The backend to use for the IRLS algorithm. (default: ``"loky"``). + + num_jobs : int + Number of CPUs to use for parallelization. (default: ``8``). + + joblib_verbosity : int + Verbosity level for joblib. (default: ``3``). + + irls_batch_size : int + Batch size for the IRLS algorithm. (default: ``100``). + + PQN_c1 : float + Parameter for the Proximal Newton algorithm. (default: ``1e-4``) (which + catches the IRLS algorithm). This is a line search parameter for the Armijo + condition. + + PQN_ftol : float + Tolerance for the Proximal Newton algorithm. (default: ``1e-7``). + + PQN_num_iters_ls : int + Number of iterations for the line search in the Proximal Newton algorithm. + (default: ``20``). + + PQN_num_iters : int + Number of iterations for the Proximal Newton algorithm. (default: ``100``). + + PQN_min_mu : float + Lower threshold for the mean expression values in + the Proximal Newton algorithm. + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + Methods + ------- + build_compute_plan + Build the computation graph to run and save the genewise dispersions. + + save_genewise_dispersions_checkpoints + Save genewise dispersions checkpoints, that is the genewise dispersions, + the MoM dispersions, and the mu_hat estimates. + + init_local_states + Initialize the local states, returning the local gram matrix and local features. + + get_local_mom_dispersions_mu_dispersions + Collect MoM dispersions and mu_hat estimate and pass on. 
+ + concatenate_mu_estimates + Concatenate initial mu_hat estimates and pass on dispersions. + + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + min_disp: float = 1e-8, + max_disp: float = 10.0, + grid_batch_size: int = 250, + grid_depth: int = 3, + grid_length: int = 100, + num_jobs=8, + min_mu: float = 0.5, + beta_tol: float = 1e-8, + max_beta: float = 30, + irls_num_iter: int = 20, + joblib_backend: str = "loky", + joblib_verbosity: int = 3, + irls_batch_size: int = 100, + PQN_c1: float = 1e-4, + PQN_ftol: float = 1e-7, + PQN_num_iters_ls: int = 20, + PQN_num_iters: int = 100, + PQN_min_mu: float = 0.0, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + min_disp=min_disp, + max_disp=max_disp, + grid_batch_size=grid_batch_size, + grid_depth=grid_depth, + grid_length=grid_length, + num_jobs=num_jobs, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + min_mu=min_mu, + beta_tol=beta_tol, + max_beta=max_beta, + irls_num_iter=irls_num_iter, + joblib_backend=joblib_backend, + joblib_verbosity=joblib_verbosity, + irls_batch_size=irls_batch_size, + PQN_c1=PQN_c1, + PQN_ftol=PQN_ftol, + PQN_num_iters_ls=PQN_num_iters_ls, + PQN_num_iters=PQN_num_iters, + PQN_min_mu=PQN_min_mu, + ) + + #### Define hyper parameters #### + + self.min_disp = min_disp + self.max_disp = max_disp + self.grid_batch_size = grid_batch_size + self.grid_depth = grid_depth + self.grid_length = grid_length + self.num_jobs = num_jobs + self.min_mu = min_mu + self.beta_tol = beta_tol + self.max_beta = max_beta + + # Parameters of the IRLS algorithm + self.irls_num_iter = irls_num_iter + self.PQN_c1 = PQN_c1 + self.PQN_ftol = PQN_ftol + self.PQN_num_iters_ls = PQN_num_iters_ls + self.PQN_num_iters = PQN_num_iters + self.PQN_min_mu = PQN_min_mu + + #### Define job parallelization parameters #### + self.joblib_verbosity = joblib_verbosity + self.num_jobs = num_jobs + self.joblib_backend = joblib_backend + self.irls_batch_size = irls_batch_size + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run and save the genewise dispersions. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). 
+ """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Fit genewise dispersions #### + ( + local_states, + genewise_dispersions_shared_state, + round_idx, + ) = self.fit_genewise_dispersions( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + gram_features_shared_states=shared_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + self.save_genewise_dispersions_checkpoints( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + genewise_dispersions_shared_state=genewise_dispersions_shared_state, + round_idx=round_idx, + clean_models=False, + ) + + def save_genewise_dispersions_checkpoints( + self, + train_data_nodes, + aggregation_node, + local_states, + genewise_dispersions_shared_state, + round_idx, + clean_models, + ): + """Save genewise dispersions checkpoints. + + This method saves the genewise dispersions, the MoM dispersions, + and the mu_hat estimates. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + genewise_dispersions_shared_state: dict + Contains the output shared state of "fit_genewise_dispersions" step, + which contains a "genewise_dispersions" field used in this test. + + round_idx: int + The current round. + + clean_models: bool + Whether to clean the models after the computation. + + """ + # ---- Get MoM dispersions and mu_hat estimates ---- # + + local_states, shared_states, round_idx = local_step( + local_method=self.get_local_mom_dispersions_mu_dispersions, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=genewise_dispersions_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Get local MoM dispersions and mu_hat", + round_idx=round_idx, + clean_models=clean_models, + ) + + # ---- Concatenate mu_hat estimates ---- # + + results_shared_state, round_idx = aggregation_step( + aggregation_method=self.concatenate_mu_estimates, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Concatenate mu_hat estimates", + round_idx=round_idx, + clean_models=clean_models, + ) + + return local_states, round_idx + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Initialize the local states. + + Here, we copy the reference dds to the local state, and + create the local gram matrix and local features, which are necessary inputs + to the genewise dispersions computation. 
+
+        Parameters
+        ----------
+        data_from_opener : ad.AnnData
+            AnnData returned by the opener.
+        shared_state : Any
+            Shared state. Not used.
+
+        Returns
+        -------
+        dict
+            A dictionary containing the "local_gram_matrix" and "local_features"
+            fields. These fields are used to compute the rough dispersions, and are
+            computed in the last step of the compute_size_factors step in the main
+            pipe.
+        """
+
+        self.local_adata = self.local_reference_dds.copy()
+        del self.local_adata.layers["_mu_hat"]
+        # This field is not saved in pydeseq2 but is used in fedpydeseq2
+        self.local_adata.uns["n_params"] = self.local_adata.obsm["design_matrix"].shape[
+            1
+        ]
+        design = self.local_adata.obsm["design_matrix"].values
+
+        return {
+            "local_gram_matrix": design.T @ design,
+            "local_features": design.T @ self.local_adata.layers["normed_counts"],
+        }
+
+    @remote_data
+    @log_remote_data
+    @reconstruct_adatas
+    def get_local_mom_dispersions_mu_dispersions(
+        self, data_from_opener: ad.AnnData, shared_state: dict
+    ) -> dict:
+        """Collect MoM dispersions and mu_hat estimate and pass on.
+
+        Here, we pass on the genewise dispersions present in the shared state,
+        and collect the MoM dispersions and local mu_hat estimates from the
+        local adata.
+
+        Parameters
+        ----------
+        data_from_opener : ad.AnnData
+            AnnData returned by the opener.
+        shared_state : dict
+            Shared state, containing the genewise dispersions in the
+            "genewise_dispersions" field.
+
+        Returns
+        -------
+        dict
+            Shared state containing the MoM dispersions in the "MoM_dispersions"
+            field, the local mu_hat estimates in the "local_mu_hat" field, the
+            genewise dispersions in the "genewise_dispersions" field, as well as
+            the sample IDs in the "sample_ids" field and the design matrix columns
+            in the "design_columns" field.
+        """
+
+        return {
+            "MoM_dispersions": self.local_adata.varm["_MoM_dispersions"],
+            "local_mu_hat": self.local_adata.layers["_mu_hat"],
+            "genewise_dispersions": shared_state["genewise_dispersions"],
+            "sample_ids": self.local_adata.obs_names,
+            "design_columns": self.local_adata.obsm["design_matrix"].columns,
+        }
+
+    @remote
+    @log_remote
+    def concatenate_mu_estimates(self, shared_states: list):
+        """Concatenate initial mu_hat estimates and pass on dispersions.
+
+        Parameters
+        ----------
+        shared_states : list
+            A list of shared states with a "local_mu_hat" key containing the
+            local mu_hat estimates, a "MoM_dispersions" key containing the MoM
+            dispersions, and a "genewise_dispersions" key containing the genewise
+            dispersions. The MoM dispersions and gene-wise dispersions are passed on
+            and are supposed to be the same across all states.
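+
+        The concatenated quantities are not returned but stored in ``self.results``,
+        from which the test pipeline retrieves them.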
+ + """ + mu_hat = np.vstack([state["local_mu_hat"] for state in shared_states]) + sample_ids = np.concatenate([state["sample_ids"] for state in shared_states]) + design_columns = shared_states[0]["design_columns"] + self.results = { + "mu_hat": mu_hat, + "MoM_dispersions": shared_states[0]["MoM_dispersions"], + "genewise_dispersions": shared_states[0]["genewise_dispersions"], + "sample_ids": sample_ids, + "design_columns": design_columns, + } diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_get_num_replicates.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_get_num_replicates.py new file mode 100644 index 0000000..398bddb --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/test_get_num_replicates.py @@ -0,0 +1,477 @@ +"""Module to test the GetNumReplicates class.""" + +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.get_num_replicates import ( # noqa: E501 + GetNumReplicates, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_get_num_replicates_on_small_genes_small_samples( + design_factors, + continuous_factors, + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + """ + Test the GetNumReplicates class on a small number of genes and samples. + + Parameters + ---------- + design_factors : str or list + The design factors to use. + + continuous_factors : list or None + The continuous factors to use. + + raw_data_path : Path + The path to the raw data. + + local_processed_data_path : Path + The path to the processed data. + + tcga_assets_directory : Path + The path to the assets directory. 
+ + """ + get_num_replicates_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_get_num_replicates_on_small_genes_on_self_hosted_fast( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """ + Test the GetNumReplicates class on a small number of genes. + + Parameters + ---------- + design_factors : str or list + The design factors to use. + + continuous_factors : list or None + The continuous factors to use. + + raw_data_path : Path + The path to the raw data. + + tmp_processed_data_path : Path + The path to the processed data. + + tcga_assets_directory : Path + The path to the assets directory. + + """ + get_num_replicates_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def get_num_replicates_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, +): + """Check that the GetNumReplicates class is working correctly. + + More specifically, we check that the num_replicates field indeed corresponds + to the pandas value_counts of the full design matrix. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. 
+ """ + complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = ( + local_processed_data_path / "centers_data" / "tcga" / experiment_id + ) + # The test happens inside the last aggregation + run_tcga_testing_pipe( + GetNumReplicatesTester( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=complete_ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=raw_data_path, + processed_data_path=local_processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + +class GetNumReplicatesTester(UnitTester, GetNumReplicates): + """A class to implement a unit test for the GetNumReplicates class. + + This class checks that that num_replicates indeed corresponds to the + pandas value_counts of the full design matrix. + + To do so, we associate the num_replicate field to the corresponding + design matrix lines, before checking that the series obtained thus + matches the value_counts of the full design matrix. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + Methods + ------- + build_compute_plan + Build the computation graph to check the get_num_replicates function. + + get_local_cells_design + Get the local cells and design. + + agg_check_value_counts + Aggregate the value counts from the cells, design and num_replicates. + This function performs the following steps: + - Create the value_counts that result from the aggregation of the cells, design + and num_replicates, i.e., the value counts seen by the local data (in the + fl_count columns) + - Create the full design matrix and compute the real value_counts from the full + design matrix. 
+ - Check that the value counts are the same, that is that get_num_replicates is + working correctly. + + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ) + + #### Define hyper parameters #### + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to check the get_num_replicates function. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + ( + local_states, + round_idx, + ) = self.get_num_replicates( + train_data_nodes, aggregation_node, local_states, round_idx, clean_models + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.get_local_cells_design, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Get local cells and design", + round_idx=round_idx, + clean_models=clean_models, + ) + aggregation_step( + aggregation_method=self.agg_check_value_counts, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Check matching value counts", + round_idx=round_idx, + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_local_cells_design(self, data_from_opener, shared_state: dict) -> dict: + """Get the local cells and design. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. 
+ + shared_state : dict + Not used + + Returns + ------- + dict + Dictionary with the following keys: + - cells: cells in the local data + - design: design matrix + - num_replicates: number of replicates + """ + cells = self.local_adata.obs["cells"] + design = self.local_adata.obsm["design_matrix"] + num_replicates = self.local_adata.uns["num_replicates"] + + return {"cells": cells, "design": design, "num_replicates": num_replicates} + + @remote + @log_remote + def agg_check_value_counts(self, shared_states: list[dict]): + """Aggregate the value counts from the cells, design and num_replicates. + + This function creates the value_counts that result from the + aggregation of the cells, design and num_replicates, i.e., the value + counts seen by the local data (in the fl_count columns) + + It then creates the full design matrix and computes the real value_counts + from the full design matrix. + + It then checks that the value counts are the same, that is that + get_num_replicates is working correctly. + + Parameters + ---------- + shared_states : list[dict] + List of shared states. Must contain the following keys: + - cells: cells in the local data + - design: design matrix + - num_replicates: number of replicates + + """ + design = pd.concat( + [shared_state["design"] for shared_state in shared_states], axis=0 + ) + cells = pd.concat( + [shared_state["cells"] for shared_state in shared_states], axis=0 + ) + num_replicates = shared_states[0]["num_replicates"] + design_columns = design.columns.tolist() + counts_col = pd.Series( + num_replicates.loc[cells].values, name="fl_count", index=cells.index + ) + extended_design = pd.concat([design, counts_col], axis=1) + # drop duplicates + extended_design = extended_design.drop_duplicates().reset_index(drop=True) + # On the other hand, compute value counts + value_counts = design.value_counts().reset_index() + # merge the two datasets + merged = pd.merge(value_counts, extended_design, on=design_columns, how="left") + # assert that the counts are the same + assert np.allclose(merged["fl_count"].values, merged["count"].values) diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/utils_genewise_dispersions.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/utils_genewise_dispersions.py new file mode 100644 index 0000000..a001ac7 --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/utils_genewise_dispersions.py @@ -0,0 +1,118 @@ +import numpy as np + +from fedpydeseq2.core.utils import vec_loss + + +def perform_dispersions_and_nll_relative_check( + fl_dispersions, + pooled_dds, + dispersions_param_name: str = "genewise_dispersions", + prior_reg: bool = False, + rtol=0.02, + atol=1e-3, + nll_rtol=0.02, + nll_atol=1e-3, + tolerated_failed_genes=0, +): + """Perform the relative error check on the dispersions and likelihoods. + + This function checks the relative error on the dispersions. If the relative error + is above rtol, it checks the likelihoods. If the likelihoods are above rtol higher + than the pooled likelihoods, we fail the test (with a certain tolerance + for failure). + + Parameters + ---------- + fl_dispersions: np.ndarray + The dispersions computed by the federated algorithm. + + pooled_dds: DeseqDataSet + The pooled DeseqDataSet. + + dispersions_param_name: str + The name of the parameter in the varm that contains the dispersions. + + prior_reg: bool + If True, the prior regularization is applied. 
+
+ alpha_hat: Optional[np.ndarray]
+ The alpha_hat values used for the prior regularization. They are not passed
+ as an argument but read from ``pooled_dds.varm["fitted_dispersions"]`` when
+ ``prior_reg`` is True.
+
+ prior_disp_var: float
+ The prior_disp_var value used for the prior regularization. It is not passed
+ as an argument but read from ``pooled_dds.uns["prior_disp_var"]`` when
+ ``prior_reg`` is True.
+
+ rtol: float
+ The relative tolerance between the FL and pooled dispersions.
+
+ atol: float
+ The absolute tolerance between the FL and pooled dispersions.
+
+ nll_rtol: float
+ The relative tolerance between the FL and pooled likelihoods, in the
+ case of a failed dispersion check.
+
+ nll_atol: float
+ The absolute tolerance between the FL and pooled likelihoods, in the
+ case of a failed dispersion check.
+
+ tolerated_failed_genes: int
+ The number of genes that are allowed to fail the relative nll criterion.
+
+ """
+ # If any of the relative errors is above rtol, check the likelihoods
+ pooled_dispersions = pooled_dds.varm[dispersions_param_name]
+ accepted_error = np.abs(pooled_dispersions) * rtol + atol
+ absolute_error = np.abs(fl_dispersions - pooled_dispersions)
+
+ if np.any(absolute_error > accepted_error):
+ to_check = absolute_error > accepted_error
+ print(
+ f"{to_check.sum()} genes do not pass the relative error criterion."
+ f" Genes that do not pass the relative error criterion with rtol {rtol} and"
+ f" atol {atol} are:"
+ )
+ print(pooled_dds.var_names[to_check])
+
+ counts = pooled_dds[:, to_check].X
+ design = pooled_dds.obsm["design_matrix"].values
+ mu = pooled_dds[:, to_check].layers["_mu_hat"]
+
+ if prior_reg:
+ alpha_hat = pooled_dds[:, to_check].varm["fitted_dispersions"]
+ prior_disp_var = pooled_dds.uns["prior_disp_var"]
+ else:
+ alpha_hat = None
+ prior_disp_var = None
+
+ # Compute the likelihoods
+ fl_nll = vec_loss(
+ counts,
+ design,
+ mu,
+ fl_dispersions[to_check],
+ prior_reg=prior_reg,
+ alpha_hat=alpha_hat,
+ prior_disp_var=prior_disp_var,
+ )
+ pooled_nll = vec_loss(
+ counts,
+ design,
+ mu,
+ pooled_dds[:, to_check].varm[dispersions_param_name],
+ prior_reg=prior_reg,
+ alpha_hat=alpha_hat,
+ prior_disp_var=prior_disp_var,
+ )
+
+ # Check that FL likelihood is smaller than pooled likelihood
+ nll_error = fl_nll - pooled_nll
+ nll_accepted_error = np.abs(pooled_nll) * nll_rtol + nll_atol
+
+ failed_nll_criterion = nll_error > nll_accepted_error
+
+ if np.sum(failed_nll_criterion) > 0:
+ print(
+ f"{failed_nll_criterion.sum()} genes do not pass the nll criterion."
+ f"The corresponding gene names are : " + ) + print(pooled_dds.var_names[to_check][failed_nll_criterion]) + + assert np.sum(failed_nll_criterion) <= tolerated_failed_genes diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/__init__.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/__init__.py new file mode 100644 index 0000000..cf3905d --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/__init__.py @@ -0,0 +1 @@ +"""Contains unit tests for the compute_lfc module.""" diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc_test_pipe.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc_test_pipe.py new file mode 100644 index 0000000..ede7741 --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc_test_pipe.py @@ -0,0 +1,569 @@ +import pickle as pkl +from pathlib import Path +from typing import Literal + +import numpy as np +import pandas as pd +from fedpydeseq2_datasets.constants import TCGADatasetNames +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from pydeseq2.dds import DeseqDataSet +from scipy.linalg import solve # type: ignore +from substra import BackendType + +from fedpydeseq2.core.utils import vec_loss +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.deseq2_core.deseq2_lfc_dispersions.compute_lfc.compute_lfc_tester import ( # noqa: E501 + ComputeLFCTester, +) +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels + + +def pipe_test_compute_lfc( + data_path: Path, + processed_data_path: Path, + tcga_assets_directory: Path, + dataset_name: TCGADatasetNames = "TCGA-LUAD", + small_samples: bool = False, + small_genes: bool = False, + simulate: bool = True, + backend: BackendType = "subprocess", + only_two_centers: bool = False, + lfc_mode: Literal["lfc", "mu_init"] = "lfc", + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, + PQN_min_mu: float = 0.0001, + rtol_irls: float = 1e-5, + atol_irls: float = 1e-8, + rtol_pqn: float = 2e-2, + atol_pqn: float = 1e-3, + nll_rtol: float = 0.02, + nll_atol: float = 1e-3, + tolerated_failed_genes: int = 5, +): + r"""Perform a unit test for the log fold change computation. + + Parameters + ---------- + data_path: Path + The path to the root data. + + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: BackendType + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + lfc_mode: str + The mode of the ComputeLFC algorithm. + + design_factors: str or list + The design factors to use. 
+ + ref_levels: dict or None + The reference levels of the design factors. + + continuous_factors : list or None + The continuous factors to use. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + + PQN_min_mu : float + The minimum mu in the prox newton method. + + tolerated_failed_genes: int + The number of tolerated failed genes. + + rtol_irls: float + The relative tolerance for the comparison of the LFC and mu values for genes + where IRLS converges. + + atol_irls: float + The absolute tolerance for the comparison of the LFC and mu values for genes + where IRLS converges. + + rtol_pqn: float + The relative tolerance for the comparison of the LFC and mu values for genes + where PQN is used. + + atol_pqn: float + The absolute tolerance for the comparison of the LFC and mu values for genes + where PQN is used. + + nll_rtol: float + The relative tolerance for the comparison of the nll values, when the LFC and mu + values are too different. + + nll_atol: float + The absolute tolerance for the comparison of the nll values, when the LFC and mu + values are too different. + + """ + complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = processed_data_path / "centers_data" / "tcga" / experiment_id + # Get FL results. + fl_results = run_tcga_testing_pipe( + ComputeLFCTester( + design_factors=design_factors, + lfc_mode=lfc_mode, + ref_levels=complete_ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + PQN_min_mu=PQN_min_mu, + ), + raw_data_path=data_path, + processed_data_path=processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # Get the sample ids and permutations to reorder + sample_ids = pooled_dds.obs_names + sample_ids_fl = fl_results["sample_ids"] + # Get the permutation that sorts the sample ids + perm = np.argsort(sample_ids) + perm_fl = np.argsort(sample_ids_fl) + + # Get the initialization beta + fl_beta_init = fl_results[f"{lfc_mode}_beta_init"] + + # FL gene name by convergence type + fl_irls_genes = fl_results[f"{lfc_mode}_irls_genes"] + fl_PQN_genes = fl_results[f"{lfc_mode}_PQN_genes"] + fl_all_diverged_genes = fl_results[f"{lfc_mode}_all_diverged_genes"] + + # Get mu and beta param names + beta_param_name = fl_results["beta_param_name"] + mu_param_name = fl_results["mu_param_name"] + + # fl mu results + fl_mu_irls_converged = fl_results[f"{mu_param_name}_irls_converged"] + 
fl_mu_PQN_converged = fl_results[f"{mu_param_name}_PQN_converged"] + fl_mu_all_diverged = fl_results[f"{mu_param_name}_all_diverged"] + + # Reorder the mu results + fl_mu_irls_converged = fl_mu_irls_converged[perm_fl] + fl_mu_PQN_converged = fl_mu_PQN_converged[perm_fl] + fl_mu_all_diverged = fl_mu_all_diverged[perm_fl] + + # fl LFC results + fl_LFC_irls_converged = fl_results[f"{beta_param_name}_irls_converged"] + fl_LFC_PQN_converged = fl_results[f"{beta_param_name}_PQN_converged"] + fl_LFC_all_diverged = fl_results[f"{beta_param_name}_all_diverged"] + + # pooled beta init + pooled_beta_init = get_beta_init_pooled(pooled_dds) + + # pooled mu results + pooled_beta_param_name, pooled_mu_param_name = get_pooled_results( + pooled_dds, lfc_mode + ) + + pooled_mu_irls_converged = pooled_dds.layers[pooled_mu_param_name][ + :, pooled_dds.var_names.get_indexer(fl_irls_genes) + ] + pooled_mu_PQN_converged = pooled_dds.layers[pooled_mu_param_name][ + :, pooled_dds.var_names.get_indexer(fl_PQN_genes) + ] + pooled_mu_all_diverged = pooled_dds.layers[pooled_mu_param_name][ + :, pooled_dds.var_names.get_indexer(fl_all_diverged_genes) + ] + + # Reorder the mu results + pooled_mu_irls_converged = pooled_mu_irls_converged[perm] + pooled_mu_PQN_converged = pooled_mu_PQN_converged[perm] + pooled_mu_all_diverged = pooled_mu_all_diverged[perm] + + # pooled LFC results + pooled_LFC_irls_converged = ( + pooled_dds.varm[pooled_beta_param_name].loc[fl_irls_genes, :].to_numpy() + ) + pooled_LFC_PQN_converged = ( + pooled_dds.varm[pooled_beta_param_name].loc[fl_PQN_genes, :].to_numpy() + ) + pooled_LFC_all_diverged = ( + pooled_dds.varm[pooled_beta_param_name].loc[fl_all_diverged_genes, :].to_numpy() + ) + + #### ---- Check for the beta init ---- #### + + assert np.allclose( + fl_beta_init, + pooled_beta_init, + equal_nan=True, + ) + + #### ---- Check for the irls_converged ---- #### + + try: + assert np.allclose( + fl_LFC_irls_converged, + pooled_LFC_irls_converged, + equal_nan=True, + rtol=rtol_irls, + atol=atol_irls, + ) + + except AssertionError: + # This is likely due to the fact that beta values are small. + # We will check the relative error for the nll. + relative_LFC_mu_nll_test( + fl_mu_irls_converged, + pooled_mu_irls_converged, + fl_LFC_irls_converged, + pooled_LFC_irls_converged, + pooled_dds, + fl_irls_genes, + rtol=rtol_irls, + atol=atol_irls, + nll_rtol=nll_rtol, + nll_atol=nll_atol, + tolerated_failed_genes=0, + ) + + #### ---- Check for the PQN_converged ---- #### + + # For genes that have converged with the prox newton method, + # we check that the hat diagonals, mu LFC and LFC are the same + # for the FL and the pooled results. + # If it is not the case, we check the relative log likelihood is not + # too different. + # If that is not the case, we check that the relative optimization error wrt + # the beta init is not too different. + # We tolerate a few failed genes. + + relative_LFC_mu_nll_test( + fl_mu_PQN_converged, + pooled_mu_PQN_converged, + fl_LFC_PQN_converged, + pooled_LFC_PQN_converged, + pooled_dds, + fl_PQN_genes, + rtol=rtol_pqn, + atol=atol_pqn, + nll_rtol=nll_rtol, + nll_atol=nll_atol, + tolerated_failed_genes=tolerated_failed_genes, + ) + + #### ---- Check for the all_diverged ---- #### + + # We perform the same checks for the genes that have not converged as well. 
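+ # Illustrative note (not part of the test logic): relative_LFC_mu_nll_test
+ # accepts a gene when |fl - pooled| <= |pooled| * rtol + atol entry-wise.
+ # For instance, with the defaults rtol_pqn=2e-2 and atol_pqn=1e-3, a pooled
+ # coefficient of 1.0 may deviate by at most 1.0 * 2e-2 + 1e-3 = 2.1e-2 before
+ # the gene is sent to the nll fallback check.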
+ + relative_LFC_mu_nll_test( + fl_mu_all_diverged, + pooled_mu_all_diverged, + fl_LFC_all_diverged, + pooled_LFC_all_diverged, + pooled_dds, + fl_all_diverged_genes, + rtol=rtol_pqn, + atol=atol_pqn, + nll_rtol=nll_rtol, + nll_atol=nll_atol, + tolerated_failed_genes=tolerated_failed_genes, + ) + + +def relative_LFC_mu_nll_test( + fl_mu_LFC: np.ndarray, + pooled_mu_LFC: np.ndarray, + fl_LFC: np.ndarray, + pooled_LFC: np.ndarray, + pooled_dds: DeseqDataSet, + fl_genes: list[str], + rtol: float = 2e-2, + atol: float = 1e-3, + nll_rtol: float = 0.02, + nll_atol: float = 1e-3, + tolerated_failed_genes: int = 5, + irsl_mode: Literal["lfc", "mu_init"] = "lfc", +): + r"""Perform the relative error test for the LFC and mu values. + + This test checks that the relative error for the LFC and mu values is not too high. + If it is too high, we check the relative error for the nll. + + Parameters + ---------- + fl_mu_LFC: np.ndarray + The mu LFC from the FL results. + + pooled_mu_LFC: np.ndarray + The mu LFC from the pooled results. + + fl_LFC: np.ndarray + The LFC from the FL results. + + pooled_LFC: np.ndarray + The LFC from the pooled results. + + pooled_dds: DeseqDataSet + The pooled DeseqDataSet. + + fl_genes: list[str] + The genes that are not IRLS converged. + + rtol: float + The relative tolerance for the comparison of the LFC and mu values. + + atol: float + The absolute tolerance for the comparison of the LFC and mu values. + + nll_rtol: float + The relative tolerance for the comparison of the nll values. + + nll_atol: float + The absolute tolerance for the comparison of the nll values. + + tolerated_failed_genes: int + The number of tolerated failed genes. + + irsl_mode: Literal["lfc", "mu_init"] + The mode of the ComputeLFC algorithm. + """ + mu_LFC_error_tol = np.abs(pooled_mu_LFC) * rtol + atol + mu_LFC_abs_error = np.abs(fl_mu_LFC - pooled_mu_LFC) + LFC_error_tol = np.abs(pooled_LFC) * rtol + atol + LFC_abs_error = np.abs(fl_LFC - pooled_LFC) + + # We check that the relative errors are not too high. + to_check_mask = (mu_LFC_abs_error > mu_LFC_error_tol).any(axis=0) | ( + LFC_abs_error > LFC_error_tol + ).any(axis=1) + + if np.sum(to_check_mask) > 0: + print( + f"{to_check_mask.sum()} genes do not pass the relative error criterion." + f" Genes that do not pass the relative error criterion with tolerance 0.02:" + ) + print(fl_genes[to_check_mask]) + print("Corresponding top mu relative error") + mu_LFC_rel_error = np.abs((fl_mu_LFC - pooled_mu_LFC) / pooled_mu_LFC) + print(np.sort(mu_LFC_rel_error[:, to_check_mask])[-10:]) + print("Corresponding LFC relative error") + LFC_rel_error = np.abs((fl_LFC - pooled_LFC) / pooled_LFC) + print(LFC_rel_error[to_check_mask, :]) + # For the genes whose relative error is too high, + # We will start by checking the relative error for the nll. + to_check_genes = fl_genes[to_check_mask] + to_check_genes_index = pooled_dds.var_names.get_indexer(to_check_genes) + + counts = pooled_dds[:, to_check_genes_index].X + design = pooled_dds.obsm["design_matrix"].values + if irsl_mode == "lfc": + dispersions = pooled_dds[:, to_check_genes_index].varm["dispersions"] + else: + dispersions = pooled_dds[:, to_check_genes_index].varm["_MoM_dispersions"] + + # We compute the mu values for the FL and the pooled results. 
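+ # The mu matrices below follow the negative binomial mean model used in the
+ # pipeline: mu[i, g] = size_factor[i] * exp(design[i, :] @ beta[g, :]),
+ # clipped from below at 0.5, mirroring the default ``min_mu`` (the lower
+ # threshold for mean expression values).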
+ size_factors = pooled_dds.obsm["size_factors"] + mu_fl = np.maximum( + size_factors[:, None] * np.exp(design @ fl_LFC[to_check_mask].T), + 0.5, + ) + mu_pooled = np.maximum( + size_factors[:, None] * np.exp(design @ pooled_LFC[to_check_mask].T), + 0.5, + ) + fl_nll = vec_loss( + counts, + design, + mu_fl, + dispersions, + ) + pooled_nll = vec_loss( + counts, + design, + mu_pooled, + dispersions, + ) + + # Note: here I test the nll and not the regularized NLL which is + # the real target of the optimization. However, this should not be + # an issue since we add only a small regularization and there is + # a bound on the beta values. + nll_error_tol = np.abs(pooled_nll) * nll_rtol + nll_atol + + failed_test_mask = (fl_nll - pooled_nll) > nll_error_tol + + # We identify the genes that do not pass the nll relative error criterion. + + if np.sum(failed_test_mask) > 0: + print( + f"{np.sum(failed_test_mask)} " + f"genes do not pass the nll relative error" + f" criterion with relative tolerance {nll_rtol} and absolute" + f"tolerance {nll_atol}." + ) + print("These genes are") + print(to_check_genes[failed_test_mask]) + print("Corresponding absolute error") + print(fl_nll[failed_test_mask] - pooled_nll[failed_test_mask]) + print("Corresponding error tolerance") + print(nll_error_tol[failed_test_mask]) + print("LFC pooled") + print(pooled_LFC[to_check_mask][failed_test_mask]) + print("LFC FL") + print(fl_LFC[to_check_mask][failed_test_mask]) + # We tolerate a few failed genes + + assert np.sum(failed_test_mask) <= tolerated_failed_genes + + +def get_pooled_results( + pooled_dds: DeseqDataSet, lfc_mode: Literal["lfc", "mu_init"] = "lfc" +) -> tuple[str, str]: + """Get the pooled results and set them in the pooled DeseqDataSet. + + Parameters + ---------- + pooled_dds : DeseqDataSet + The pooled DeseqDataSet. + + lfc_mode : Literal["lfc", "mu_init"] + The mode of the ComputeLFC algorithm. + If "lfc", the results are already computed. + If "mu_init", the results are computed using the pooled DeseqDataSet, as they + are not saved in the DeseqDataSet object. + + Returns + ------- + tuple[str, str] + The beta and mu parameter names , by which we can access them in the pooled + dataset. + + """ + if lfc_mode == "lfc": + return "LFC", "_mu_LFC" + + design_matrix = pooled_dds.obsm["design_matrix"].values + + mle_lfcs_, mu_, _, _ = pooled_dds.inference.irls( + counts=pooled_dds.X[:, pooled_dds.non_zero_idx], + size_factors=pooled_dds.obsm["size_factors"], + design_matrix=design_matrix, + disp=pooled_dds.varm["_MoM_dispersions"][pooled_dds.non_zero_idx], + min_mu=pooled_dds.min_mu, + beta_tol=pooled_dds.beta_tol, + ) + + pooled_dds.varm["_LFC_mu_hat"] = pd.DataFrame( + np.nan, + index=pooled_dds.var_names, + columns=pooled_dds.obsm["design_matrix"].columns, + ) + + pooled_dds.varm["_LFC_mu_hat"].update( + pd.DataFrame( + mle_lfcs_, + index=pooled_dds.non_zero_genes, + columns=pooled_dds.obsm["design_matrix"].columns, + ) + ) + + pooled_dds.layers["_mu_hat"] = np.full( + (pooled_dds.n_obs, pooled_dds.n_vars), np.nan + ) + pooled_dds.layers["_mu_hat"][:, pooled_dds.varm["non_zero"]] = mu_ + + return "_LFC_mu_hat", "_mu_hat" + + +def get_beta_init_pooled(pooled_dds: DeseqDataSet) -> np.ndarray: + """Get the initial beta values for the pooled DeseqDataSet. + + These initial beta values are used to initialize the optimization + of the log fold changes. 
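+
+ Concretely, when the design matrix has full column rank, beta is obtained
+ from a QR-based least-squares fit of log(counts / size_factors + 0.1) on the
+ design matrix; otherwise, only the intercept is initialized, with the mean of
+ the log normalized counts.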
+ + Parameters + ---------- + pooled_dds : DeseqDataSet + The reference pooled DeseqDataSet from which we want to compute the initial + value of beta when computing log fold changes. + + Returns + ------- + np.ndarray + The initial beta values. + + """ + design_matrix = pooled_dds.obsm["design_matrix"].values + counts = pooled_dds.X[:, pooled_dds.non_zero_idx] + size_factors = pooled_dds.obsm["size_factors"] + + num_vars = design_matrix.shape[1] + X = design_matrix + if np.linalg.matrix_rank(X) == num_vars: + Q, R = np.linalg.qr(X) + y = np.log(counts / size_factors[:, None] + 0.1) + beta_init = solve(R, Q.T @ y) + else: # Initialise intercept with log base mean + beta_init = np.zeros(num_vars) + beta_init[0] = np.log(counts / size_factors[:, None]).mean() + + return beta_init.T diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc_tester.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc_tester.py new file mode 100644 index 0000000..5b7433f --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc_tester.py @@ -0,0 +1,425 @@ +"""A class to implement a unit tester class for ComputeLFC.""" +from pathlib import Path +from typing import Any +from typing import Literal + +import anndata as ad +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_lfc import ComputeLFC +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from tests.unit_tests.deseq2_core.deseq2_lfc_dispersions.compute_lfc.substeps import ( + AggConcatenateHandMu, +) +from tests.unit_tests.deseq2_core.deseq2_lfc_dispersions.compute_lfc.substeps import ( + LocGetLocalComputeLFCResults, +) +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +class ComputeLFCTester( + UnitTester, + ComputeLFC, + LocGetLocalComputeLFCResults, + AggConcatenateHandMu, +): + """A class to implement a unit test for ComputeLFC. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + lfc_mode : Literal["lfc", "mu_init"] + The mode of the IRLS algorithm. (default: ``"lfc"``). + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + min_mu : float + Lower threshold for mean expression values. (default: ``0.5``). + + beta_tol : float + Tolerance for the beta coefficients. (default: ``1e-8``). + + max_beta : float + Upper threshold for the beta coefficients. (default: ``30``). + + irls_num_iter : int + Number of iterations for the IRLS algorithm. (default: ``20``). 
+ + joblib_backend : str + The backend to use for the IRLS algorithm. (default: ``"loky"``). + + num_jobs : int + Number of CPUs to use for parallelization. (default: ``8``). + + joblib_verbosity : int + Verbosity level for joblib. (default: ``3``). + + irls_batch_size : int + Batch size for the IRLS algorithm. (default: ``100``). + + PQN_c1 : float + Parameter for the Proximal Newton algorithm. (default: ``1e-4``) (which + catches the IRLS algorithm). This is a line search parameter for the Armijo + condition. + + PQN_ftol : float + Tolerance for the Proximal Newton algorithm. (default: ``1e-7``). + + PQN_num_iters_ls : int + Number of iterations for the line search in the Proximal Newton algorithm. + (default: ``20``). + + PQN_num_iters : int + Number of iterations for the Proximal Newton algorithm. (default: ``100``). + + PQN_min_mu : float + Lower threshold for the mean expression values in + the Proximal Newton algorithm. + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + Methods + ------- + init_local_states + A remote_data method which sets the local adata and the local gram matrix + from the reference_dds. + + compute_gram_matrix + A remote method which computes the gram matrix. + + set_gram_matrix + A remote_data method which sets the gram matrix in the local adata. + + build_compute_plan + Build the computation graph to run a ComputeLFC algorithm. + + save_irls_results + The method to save the IRLS results. + + + """ + + def __init__( + self, + design_factors: str | list[str], + lfc_mode: Literal["lfc", "mu_init"] = "lfc", + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + min_mu: float = 0.5, + beta_tol: float = 1e-8, + max_beta: float = 30, + irls_num_iter: int = 20, + joblib_backend: str = "loky", + num_jobs: int = 8, + joblib_verbosity: int = 0, + irls_batch_size: int = 100, + PQN_c1: float = 1e-4, + PQN_ftol: float = 1e-7, + PQN_num_iters_ls: int = 20, + PQN_num_iters: int = 100, + PQN_min_mu: float = 0.0, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + ): + # Add all arguments to super init so that they can be retrieved by nodes. 
+ super().__init__( + design_factors=design_factors, + lfc_mode=lfc_mode, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + min_mu=min_mu, + beta_tol=beta_tol, + max_beta=max_beta, + irls_num_iter=irls_num_iter, + joblib_backend=joblib_backend, + num_jobs=num_jobs, + joblib_verbosity=joblib_verbosity, + irls_batch_size=irls_batch_size, + PQN_c1=PQN_c1, + PQN_ftol=PQN_ftol, + PQN_num_iters_ls=PQN_num_iters_ls, + PQN_num_iters=PQN_num_iters, + PQN_min_mu=PQN_min_mu, + ) + + #### Define hyper parameters #### + + self.min_mu = min_mu + self.beta_tol = beta_tol + self.max_beta = max_beta + + # Parameters of the IRLS algorithm + self.lfc_mode = lfc_mode + self.irls_num_iter = irls_num_iter + self.PQN_c1 = PQN_c1 + self.PQN_ftol = PQN_ftol + self.PQN_num_iters_ls = PQN_num_iters_ls + self.PQN_num_iters = PQN_num_iters + self.PQN_min_mu = PQN_min_mu + + #### Define job parallelization parameters #### + self.joblib_verbosity = joblib_verbosity + self.num_jobs = num_jobs + self.joblib_backend = joblib_backend + self.irls_batch_size = irls_batch_size + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Initialize the local states. + + This methods sets the local_adata and the local gram matrix. + from the reference_dds. + + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Not used. + + Returns + ------- + local_states : dict + Local states containing the local Gram matrix. + """ + + # Using super().init_local_states_from_opener(data_from_opener, shared_state) + # does not work, so we have to duplicate the code + self.local_adata = self.local_reference_dds.copy() + # Delete the unused "_mu_hat" layer + del self.local_adata.layers["_mu_hat"] + # This field is not saved in pydeseq2 but used in fedpyseq2 + self.local_adata.uns["n_params"] = self.local_adata.obsm["design_matrix"].shape[ + 1 + ] + # Get the local gram matrix for all genes + design_matrix = self.local_adata.obsm["design_matrix"].values + + return { + "local_gram_matrix": design_matrix.T @ design_matrix, + } + + @remote + @log_remote + def compute_gram_matrix(self, shared_states: list[dict]) -> dict: + """Compute the gram matrix. + + Parameters + ---------- + shared_states : list + List of shared states. + + Returns + ------- + dict + Dictionary containing the global gram matrix. + """ + + # Sum the local gram matrices + tot_gram_matrix = sum([state["local_gram_matrix"] for state in shared_states]) + # Share it with the centers + return { + "global_gram_matrix": tot_gram_matrix, + } + + @remote_data + @log_remote_data + @reconstruct_adatas + def set_gram_matrix(self, data_from_opener: ad.AnnData, shared_state: Any): + """Set the gram matrix in the local adata. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Shared state containing the global gram matrix. + + """ + self.local_adata.uns["_global_gram_matrix"] = shared_state["global_gram_matrix"] + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a FedDESeq2 pipe. 
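+
+ The plan chains the steps of this unit test: load the reference DeseqDataSet
+ into the local states, aggregate the per-center gram matrices into a global
+ gram matrix and share it back, run the ComputeLFC algorithm in the requested
+ ``lfc_mode``, and finally collect the IRLS/PQN results for evaluation.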
+ + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Compute the gram matrix #### + + gram_matrix_shared_state, round_idx = aggregation_step( + aggregation_method=self.compute_gram_matrix, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Compute the global gram matrix.", + clean_models=clean_models, + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.set_gram_matrix, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=gram_matrix_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Set the global gram matrix in the local adata.", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Perform fedIRLS #### + + local_states, round_idx = self.compute_lfc( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + lfc_mode=self.lfc_mode, + ) + + self.save_irls_results( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=False, + ) + + def save_irls_results( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """Save the IRLS results. + + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: int + The current round. + + clean_models: bool + If True, the models are cleaned. 
+ + """ + ( + local_states, + local_irls_results_shared_states, + round_idx, + ) = local_step( + local_method=self.get_local_irls_results, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Get the local IRLS results.", + clean_models=clean_models, + ) + + aggregation_step( + aggregation_method=self.concatenate_irls_outputs, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=local_irls_results_shared_states, + round_idx=round_idx, + description="Compute global IRLS inverse hat matrix and last nll.", + clean_models=clean_models, + ) diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py new file mode 100644 index 0000000..cfef7df --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py @@ -0,0 +1,230 @@ +"""Substeps for the ComputeLFC testing class, to aggregate information""" + + +import numpy as np +from anndata import AnnData +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data + + +class LocGetLocalComputeLFCResults: + """Get the local ComputeLFC results. + + Attributes + ---------- + local_adata: AnnData + The local AnnData. + + + Methods + ------- + get_local_irls_results + Get the local ComputeLFC results. + """ + + local_adata: AnnData + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_local_irls_results(self, data_from_opener: AnnData, shared_state: dict): + """Get the local ComputeLFC results. + + Parameters + ---------- + data_from_opener: AnnData + Not used. + + shared_state: dict + Not used. + + Returns + ------- + dict + The state to share to the server. + It contains the following fields: + - beta: ndarray or None + The current beta, of shape (n_non_zero_genes, n_params). + - mu: ndarray or None + The mu, of shape (n_obs, n_non_zero_genes). + - hat_diag: ndarray or None + The hat diagonal, of shape (n_obs, n_non_zero_genes). + - gene_names: list[str] + The names of the genes that are still active (non zero gene names + on the irls_mask). + - lfc_mode: str + The mode of the ComputeLFC algorithm. + For now, only "lfc" is supported. + - mu_param_name: str or None + The name of the mu parameter in the adata. + - beta_param_name: str + The name of the beta parameter in the adata. 
+ + """ + mu_param_name = self.local_adata.uns["_irls_mu_param_name"] + beta_param_name = self.local_adata.uns["_irls_beta_param_name"] + lfc_mode = self.local_adata.uns["_lfc_mode"] + + irls_diverged_mask = self.local_adata.uns["_irls_diverged_mask"] + PQN_diverged_mask = self.local_adata.uns["_PQN_diverged_mask"] + + non_zero_genes_names = self.local_adata.var_names[ + self.local_adata.varm["non_zero"] + ] + non_zero_genes_mask = self.local_adata.varm["non_zero"] + + # Get the initial beta value + beta_init = self.local_adata.uns["_irls_beta_init"] + + # Get beta from the varm of the local adata + beta_dataframe = self.local_adata.varm[beta_param_name] + beta: np.ndarray | None = beta_dataframe.loc[non_zero_genes_names, :].to_numpy() + assert beta is not None + beta_irls_converged = beta[~irls_diverged_mask] + beta_PQN_converged = beta[irls_diverged_mask & ~PQN_diverged_mask] + beta_all_diverged = beta[irls_diverged_mask & PQN_diverged_mask] + + # Get mu from the layers of the local adata + if mu_param_name is not None: + mu: np.ndarray | None = self.local_adata.layers[mu_param_name][ + :, non_zero_genes_mask + ] + assert mu is not None + mu_irls_converged = mu[:, ~irls_diverged_mask] + mu_PQN_converged = mu[:, irls_diverged_mask & ~PQN_diverged_mask] + mu_all_diverged = mu[:, irls_diverged_mask & PQN_diverged_mask] + else: + mu_irls_converged = None + mu_PQN_converged = None + mu_all_diverged = None + + irls_genes = non_zero_genes_names[~irls_diverged_mask] + PQN_genes = non_zero_genes_names[irls_diverged_mask & ~PQN_diverged_mask] + all_diverged_genes = non_zero_genes_names[ + irls_diverged_mask & PQN_diverged_mask + ] + + # Get the sample ids of the local adata + sample_ids = self.local_adata.obs_names + + shared_state = { + "beta_param_name": beta_param_name, + "mu_param_name": mu_param_name, + "beta_irls_converged": beta_irls_converged, + "beta_PQN_converged": beta_PQN_converged, + "beta_all_diverged": beta_all_diverged, + "mu_irls_converged": mu_irls_converged, + "mu_PQN_converged": mu_PQN_converged, + "mu_all_diverged": mu_all_diverged, + "irls_genes": irls_genes, + "PQN_genes": PQN_genes, + "all_diverged_genes": all_diverged_genes, + "lfc_mode": lfc_mode, + "beta_init": beta_init, + "sample_ids": sample_ids, + } + return shared_state + + +class AggConcatenateHandMu: + """Mixin to concatenate the hat matrix and mu. + + Methods + ------- + concatenate_irls_outputs + Concatenate that hat and mu matrices, in order to save these outputs for + evaluation. + """ + + @remote + @log_remote + def concatenate_irls_outputs(self, shared_states: dict): + """Concatenate that hat and mu matrices. + + Parameters + ---------- + shared_states : list[dict] + The shared states. + It is a list of dictionaries containing the following + keys: + - beta: ndarray or None + The current beta, of shape (n_non_zero_genes, n_params). + - mu: ndarray or None + The mu, of shape (n_obs, n_non_zero_genes). + - hat_diag: ndarray or None + The hat diagonal, of shape (n_obs, n_non_zero_genes). + - gene_names: list[str] + The names of the genes that are still active (non zero gene names + on the irls_mask). + - lfc_mode: str + The mode of the ComputeLFC algorithm. + For now, only "lfc" is supported. + - mu_param_name: str or None + The name of the mu parameter in the adata. + - beta_param_name: str or None + The name of the beta parameter in the adata. 
+ + + Returns + ------- + dict + A dictionary of results containing the following fields: + - beta_param_name: np.ndarray or None + The current beta, of shape (n_non_zero_genes, n_params). + - mu_param_name: np.ndarray or None + The mu, of shape (n_obs, n_non_zero_genes). + - {lfc_mode}_gene_names: list[str] + The names of the genes that are still active (non zero gene names) + """ + # Get the sample independent quantities + beta_irls_converged = shared_states[0]["beta_irls_converged"] + beta_PQN_converged = shared_states[0]["beta_PQN_converged"] + beta_all_diverged = shared_states[0]["beta_all_diverged"] + irls_genes = shared_states[0]["irls_genes"] + PQN_genes = shared_states[0]["PQN_genes"] + all_diverged_genes = shared_states[0]["all_diverged_genes"] + beta_param_name = shared_states[0]["beta_param_name"] + mu_param_name = shared_states[0]["mu_param_name"] + lfc_mode = shared_states[0]["lfc_mode"] + beta_init = shared_states[0]["beta_init"] + + # Concatenate mu + if mu_param_name is not None: + mu_irls_converged: np.ndarray | None = np.concatenate( + [state["mu_irls_converged"] for state in shared_states], axis=0 + ) + mu_PQN_converged: np.ndarray | None = np.concatenate( + [state["mu_PQN_converged"] for state in shared_states], axis=0 + ) + mu_all_diverged: np.ndarray | None = np.concatenate( + [state["mu_all_diverged"] for state in shared_states], axis=0 + ) + else: + mu_irls_converged = None + mu_PQN_converged = None + mu_all_diverged = None + + # Conctenate sample ids + sample_ids = np.concatenate( + [state["sample_ids"] for state in shared_states], axis=0 + ) + + self.results = { + "beta_param_name": beta_param_name, + "mu_param_name": mu_param_name, + f"{beta_param_name}_irls_converged": beta_irls_converged, + f"{beta_param_name}_PQN_converged": beta_PQN_converged, + f"{beta_param_name}_all_diverged": beta_all_diverged, + f"{mu_param_name}_irls_converged": mu_irls_converged, + f"{mu_param_name}_PQN_converged": mu_PQN_converged, + f"{mu_param_name}_all_diverged": mu_all_diverged, + f"{lfc_mode}_irls_genes": irls_genes, + f"{lfc_mode}_PQN_genes": PQN_genes, + f"{lfc_mode}_all_diverged_genes": all_diverged_genes, + f"{lfc_mode}_beta_init": beta_init, + "sample_ids": sample_ids, + } diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/test_compute_lfc.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/test_compute_lfc.py new file mode 100644 index 0000000..d5d9257 --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/compute_lfc/test_compute_lfc.py @@ -0,0 +1,303 @@ +"""Unit tests for the compute_lfc module.""" + +import pytest + +from tests.unit_tests.deseq2_core.deseq2_lfc_dispersions.compute_lfc.compute_lfc_test_pipe import ( # noqa: E501 + pipe_test_compute_lfc, +) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_compute_lfc_small_genes( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Perform a unit test to see if compute_lfc is working as expected. + + Note that the catching of IRLS is very simple here, as there are not enough + genes to observe significant differences in the log fold changes. + + The behaviour of the fed prox algorithm is tested on a self hosted runner. 
+
+ Moreover, we only test with the Fisher scaling mode, as the other modes are
+ tested in the other tests, and perform less well in our tested datasets.
+
+ We do not clip mu as this seems to yield better results.
+
+ Parameters
+ ----------
+ design_factors: str or list[str]
+ The design factors.
+
+ continuous_factors: list[str] or None
+ The continuous factors.
+
+ raw_data_path: Path
+ The path to the root data.
+
+ tmp_processed_data_path: Path
+ The path to the processed data. The subdirectories will
+ be created if needed.
+
+ tcga_assets_directory: Path
+ The path to the assets directory. It must contain the
+ opener.py file and the description.md file.
+
+ """
+
+ pipe_test_compute_lfc(
+ data_path=raw_data_path,
+ lfc_mode="lfc",
+ processed_data_path=tmp_processed_data_path,
+ tcga_assets_directory=tcga_assets_directory,
+ dataset_name="TCGA-LUAD",
+ small_samples=False,
+ small_genes=True,
+ simulate=True,
+ backend="subprocess",
+ only_two_centers=False,
+ design_factors=design_factors,
+ continuous_factors=continuous_factors,
+ ref_levels={"stage": "Advanced"},
+ reference_dds_ref_level=None,
+ PQN_min_mu=0.0,
+ )
+
+
+@pytest.mark.self_hosted_fast
+@pytest.mark.parametrize(
+ "design_factors, continuous_factors",
+ [
+ ("stage", None),
+ (["stage", "gender"], None),
+ (["stage", "gender", "CPE"], ["CPE"]),
+ ],
+)
+@pytest.mark.usefixtures(
+ "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory"
+)
+def test_compute_lfc_small_samples(
+ design_factors,
+ continuous_factors,
+ raw_data_path,
+ tmp_processed_data_path,
+ tcga_assets_directory,
+):
+ """Perform a unit test to see if compute_lfc is working as expected.
+
+ This test focuses on a small number of samples, to see if the algorithm is working
+ as expected in the self hosted CI. Note that on a small number of samples, the
+ algorithm is less performant than when there are more samples (see the
+ test_compute_lfc_small_genes test).
+ This can be explained by the fact that the log likelihood
+ is somewhat less smooth when there are few data points.
+
+ Note that, for a reason that is not clear, a tolerance of 1e-5 is too strict
+ for IRLS-converged genes (even though, in theory, the algorithm is equivalent
+ to the pooled one).
+ However, the results are still quite close to the pooled ones, and we do not
+ investigate this further for now.
+
+ Parameters
+ ----------
+ design_factors: str or list[str]
+ The design factors.
+
+ continuous_factors: list[str] or None
+ The continuous factors.
+
+ raw_data_path: Path
+ The path to the root data.
+
+ tmp_processed_data_path: Path
+ The path to the processed data. The subdirectories will
+ be created if needed.
+
+ tcga_assets_directory: Path
+ The path to the assets directory. It must contain the
+ opener.py file and the description.md file.
+ + """ + + pipe_test_compute_lfc( + lfc_mode="lfc", + data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + PQN_min_mu=0.0, + tolerated_failed_genes=5, + rtol_irls=1e-3, + atol_irls=1e-5, + nll_rtol=0.04, + nll_atol=1e-3, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_compute_lfc_luad( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Perform a unit test to see if computing lfc is working as expected. + + This test focuses on a large number of samples and genes, to see if the algorithm + is working as expected in the self hosted CI. + + Note that a relative tolerance of 1e-3 is used, instead of the default 1e-5. The + reasons for which this is needed are not clear, but the results are still quite + close to the pooled ones, and we do not investigate this further for now. + + Parameters + ---------- + design_factors: Union[str, List[str]] + The design factors. + + continuous_factors: Union[str, List[str]] + The continuous factors. + + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + """ + + pipe_test_compute_lfc( + lfc_mode="lfc", + data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + PQN_min_mu=0.0, + tolerated_failed_genes=5, + rtol_irls=1e-3, + atol_irls=1e-5, + nll_rtol=0.02, + nll_atol=1e-3, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_compute_lfc_paad( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Perform a unit test to see if computing lfc is working as expected. + + This test focuses on a large number of samples and genes, to see if the algorithm + is working as expected in the self hosted CI. + + Note that we do not add CPE as a continuous factor, as it is not present in the + PAAD dataset. + + Note that a relative tolerance of 1e-2 is used, instead of the default 1e-5. The + reasons for which this is needed are not clear, but the results are still quite + close to the pooled ones, and we do not investigate this further for now. + + Parameters + ---------- + design_factors: Union[str, List[str]] + The design factors. 
+ + continuous_factors: Union[str, List[str]] + The continuous factors. + + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + """ + + pipe_test_compute_lfc( + lfc_mode="lfc", + data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-PAAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + PQN_min_mu=0.0, + tolerated_failed_genes=5, + rtol_irls=1e-2, + atol_irls=1e-2, + nll_rtol=0.02, + nll_atol=1e-3, + ) diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_MAP_dispersions.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_MAP_dispersions.py new file mode 100644 index 0000000..d5928cc --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_MAP_dispersions.py @@ -0,0 +1,792 @@ +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +import numpy as np +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_MAP_dispersions import ( # noqa: E501 + ComputeMAPDispersions, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.pass_on_results import AggPassOnResults +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.utils_genewise_dispersions import ( # noqa: E501 + perform_dispersions_and_nll_relative_check, +) +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels +from tests.unit_tests.unit_test_helpers.pass_on_first_shared_state import ( + AggPassOnFirstSharedState, +) +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +def test_MAP_dispersions_on_small_genes_small_samples( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, +): + MAP_dispersions_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + 
reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "design_factors, continuous_factors, tolerated_failed_genes", + [ + ("stage", None, 0), + (["stage", "gender"], None, 0), + (["stage", "gender", "CPE"], ["CPE"], 1), + ], +) +def test_MAP_dispersions_on_small_samples_on_self_hosted_fast( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, + tolerated_failed_genes, +): + MAP_dispersions_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + tolerated_failed_genes=tolerated_failed_genes, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +def test_MAP_dispersions_on_self_hosted_slow( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, +): + MAP_dispersions_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + ], +) +def test_MAP_dispersions_paad_on_self_hosted_slow( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, +): + MAP_dispersions_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-PAAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def MAP_dispersions_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, + rtol: float = 0.02, + atol: float = 1e-3, + nll_rtol: float = 0.02, + nll_atol: float = 1e-3, + tolerated_failed_genes: int = 0, +): + """Perform a unit test for the MAP dispersions. + + Starting with the same dispersion trend curve as the reference dataset, compute MAP + dispersions and compare the results with the reference. 
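+
+ Concretely, the FL MAP dispersions are first compared to the pooled ones with
+ a relative/absolute tolerance; genes failing that check are still accepted if
+ their regularized negative log-likelihood under the FL dispersions is not
+ meaningfully worse than the pooled one, and at most ``tolerated_failed_genes``
+ genes may fail this fallback criterion.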
+ + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + + rtol: float + The relative tolerance for between the FL and pooled dispersions. + + atol: float + The absolute tolerance for between the FL and pooled dispersions. + + nll_rtol: float + The relative tolerance for between the FL and pooled likelihoods, in the + case of a failed dispersion check. + + nll_atol: float + The absolute tolerance for between the FL and pooled likelihoods, in the + case of a failed dispersion check. + + tolerated_failed_genes: int + The number of genes that are allowed to fail the test. + + """ + complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = ( + local_processed_data_path / "centers_data" / "tcga" / experiment_id + ) + # Get FL results. 
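+ # The call below runs the MAPDispersionsTester compute plan on the TCGA data
+ # split across centers (in simulation or subprocess mode) and returns the
+ # results dictionary saved by the tester; its "MAP_dispersions" entry is
+ # compared to the pooled reference further down.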
+ fl_results = run_tcga_testing_pipe( + MAPDispersionsTester( + design_factors=design_factors, + ref_levels=complete_ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=raw_data_path, + processed_data_path=local_processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + local_processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # Tests that the MAP dispersions are close to the pooled ones, or if not, + # that the adjusted log likelihood is close or better + + fl_dispersions = fl_results["MAP_dispersions"] + perform_dispersions_and_nll_relative_check( + fl_dispersions, + pooled_dds, + dispersions_param_name="MAP_dispersions", + prior_reg=True, + rtol=rtol, + atol=atol, + nll_rtol=nll_rtol, + nll_atol=nll_atol, + tolerated_failed_genes=tolerated_failed_genes, + ) + + +class MAPDispersionsTester( + UnitTester, ComputeMAPDispersions, AggPassOnResults, AggPassOnFirstSharedState +): + """A class to implement a unit test for the MAP dispersions. + + Note that this test checks the MAP dispersions BEFORE filtering. + The filtering is also done in the ComputeMAPDispersions class, but is + tested separately. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + min_disp : float + Lower threshold for dispersion parameters. (default: ``1e-8``). + + max_disp : float + Upper threshold for dispersion parameters. + Note: The threshold that is actually enforced is max(max_disp, len(counts)). + (default: ``10``). + + grid_batch_size : int + The number of genes to put in each batch for local parallel processing. + (default: ``100``). + + grid_depth : int + The number of grid interval selections to perform (if using GridSearch). + (default: ``3``). + + grid_length : int + The number of grid points to use for the grid search (if using GridSearch). + (default: ``100``). + + num_jobs : int + The number of jobs to use for local parallel processing in MLE tasks. + (default: ``8``). + + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. 
+ + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + Methods + ------- + init_local_states + A remote_data method which copies the reference dds to the local state + and adds the non_zero mask. + It sets the max_disp to the maximum of the reference dds number of samples + and the max_disp parameter. + It returns a dictionary with the number of samples in + the "num_samples" field. + + sum_num_samples + A remote method which computes the total number of samples to set max_disp. + It returns a dictionary with the total number of samples in the + "tot_num_samples" field. + + set_max_disp + A remote_data method which sets max_disp using the total number of samples in + the study. + It returns a dictionary with the fitted dispersions in the "fitted_dispersions" + field and the prior variance of the dispersions in the "prior_disp_var" field. + + get_MAP_dispersions + A remote_data method which gets the filtered dispersions. + It returns a dictionary with the MAP dispersions in the "MAP_dispersions" field. + + create_trend_curve_fitting_shared_state + A method which creates the trend curve fitting shared state from reference. + + save_MAP_dispersions + A method which saves the MAP dispersions. + + build_compute_plan + A method which builds the computation graph to test the computation of the MAP + dispersions. + + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + min_disp: float = 1e-8, + max_disp: float = 10.0, + grid_batch_size: int = 250, + grid_depth: int = 3, + grid_length: int = 100, + num_jobs=8, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + min_disp=min_disp, + max_disp=max_disp, + grid_batch_size=grid_batch_size, + grid_depth=grid_depth, + grid_length=grid_length, + num_jobs=num_jobs, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ) + + #### Define hyper parameters #### + + self.min_disp = min_disp + self.max_disp = max_disp + self.grid_batch_size = grid_batch_size + self.grid_depth = grid_depth + self.grid_length = grid_length + self.num_jobs = num_jobs + + # Add layers to save + self.layers_to_save_on_disk = {"local_adata": ["_mu_hat"], "refit_adata": None} + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to test the computation of the MAP dispersions. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). 
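+
+        Notes
+        -----
+        Along the way, ``max_disp`` is set to the maximum of ``self.max_disp`` and
+        the total number of samples across all centers. As a hypothetical example,
+        three centers with 40, 35 and 25 samples give
+        ``max_disp = max(10.0, 100) = 100``.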
+ """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Create trend curve fitting shared state #### + + ( + local_states, + trend_curve_shared_state, + round_idx, + ) = self.create_trend_curve_fitting_shared_state( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Fit MAP dispersions with MLE #### + + local_states, round_idx = self.fit_MAP_dispersions( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + shared_state=trend_curve_shared_state, + round_idx=round_idx, + clean_models=clean_models, + ) + + self.save_MAP_dispersions( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + def create_trend_curve_fitting_shared_state( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """Create the trend curve fitting shared state from reference. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of train data nodes. + + aggregation_node : AggregationNode + Aggregation node. + + local_states : dict + Dictionary of local states. + + round_idx : int + Round index. + + clean_models : bool + Whether to clean the models after the computation. + + Returns + ------- + local_states : dict + Dictionary of local states. + + trend_curve_shared_state : dict + Trend curve shared state. It is a dictionary with a field + "fitted_dispersion" containing the fitted dispersions from the trend curve, + and a field "prior_disp_var" containing the prior variance + of the dispersions. + + round_idx : int + Round index. 
+ + + """ + #### Load reference dataset as local_adata and set local states #### + + local_states, init_shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + ### Aggregation step to compute the total number of samples ### + shared_state, round_idx = aggregation_step( + aggregation_method=self.sum_num_samples, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=init_shared_states, + description="Get the total number of samples.", + round_idx=round_idx, + clean_models=clean_models, + ) + + local_states, max_disp_shared_states, round_idx = local_step( + local_method=self.set_max_disp, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=shared_state, + aggregation_id=aggregation_node.organization_id, + description="Compute max_disp and forward fitted dispersions", + round_idx=round_idx, + clean_models=clean_models, + ) + + trend_curve_shared_state, round_idx = aggregation_step( + aggregation_method=self.pass_on_shared_state, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=max_disp_shared_states, + description="Pass on the shared state", + round_idx=round_idx, + clean_models=clean_models, + ) + + return local_states, trend_curve_shared_state, round_idx + + def save_MAP_dispersions( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """Save the MAP dispersions. + + This method gets the MAP dispersions from the local states and saves them in the + results field of the local states. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of train data nodes. + + aggregation_node : AggregationNode + Aggregation node. + + local_states : dict + Dictionary of local states. + + round_idx : int + Round index. + + clean_models : bool + Whether to clean the models after the computation. + Note that the last step is not cleaned. + + """ + # ---- Get the filtered dispersions ---- # + local_states, shared_states, round_idx = local_step( + local_method=self.get_MAP_dispersions, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Get filtered dispersions.", + round_idx=round_idx, + clean_models=clean_models, + ) + + # ---- Save the MAP dispersions ---- # + + aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Save the MAP dispersions.", + round_idx=round_idx, + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Copy the reference dds to the local state and add the non_zero mask. + + Set the max_disp to the maximum of the reference dds number of samples and the + max_disp parameter. + + Returns a dictionary with the number of samples in the "num_samples" field. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. 
+ + shared_state : Any + Not used. + + Returns + ------- + dict + The number of samples in the "num_samples" field. + """ + + self.local_adata = self.local_reference_dds.copy() + self.local_adata.varm["non_zero"] = self.local_reference_dds.varm["non_zero"] + self.local_adata.uns["max_disp"] = max( + self.max_disp, self.local_reference_dds.n_obs + ) + + return { + "num_samples": self.local_adata.n_obs, + } + + @remote + @log_remote + def sum_num_samples(self, shared_states): + """Compute the total number of samples to set max_disp. + + Parameters + ---------- + shared_states : list + List of initial shared states copied from the reference adata. + + Returns + ------- + dict + The total number of samples in the "tot_num_samples" field. + + """ + tot_num_samples = np.sum([state["num_samples"] for state in shared_states]) + return {"tot_num_samples": tot_num_samples} + + @remote_data + @log_remote_data + @reconstruct_adatas + def set_max_disp(self, data_from_opener: ad.AnnData, shared_state: Any) -> dict: + """Set max_disp using the total number of samples in the study. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Shared state with the total number of samples in the "tot_num_samples" field. + + Returns + ------- + dict + A dictionary with the fitted dispersions in the "fitted_dispersions" field + and the prior variance of the dispersions in the "prior_disp_var" field. + """ + + self.local_adata.uns["max_disp"] = max( + self.max_disp, shared_state["tot_num_samples"] + ) + + return { + "fitted_dispersions": self.local_adata.varm["fitted_dispersions"], + "prior_disp_var": self.local_adata.uns["prior_disp_var"], + } + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_MAP_dispersions(self, data_from_opener, shared_state): + """Get the filtered dispersions. + + Returns + ------- + dict + A dictionary with the MAP dispersions in the "MAP_dispersions" field.
+ """ + return {"MAP_dispersions": self.local_adata.varm["MAP_dispersions"]} diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_MAP_dispersions_filtering.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_MAP_dispersions_filtering.py new file mode 100644 index 0000000..0f51199 --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_MAP_dispersions_filtering.py @@ -0,0 +1,469 @@ +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +import numpy as np +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_MAP_dispersions.substeps import ( # noqa: E501 + LocFilterMAPDispersions, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.pass_on_results import AggPassOnResults +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_MAP_dispersions_filtering_on_small_genes_small_samples( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + MAP_dispersions_filtering_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_MAP_dispersions_filtering_on_small_samples_on_self_hosted_fast( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + MAP_dispersions_filtering_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_MAP_dispersions_filtering_on_self_hosted_slow( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + MAP_dispersions_filtering_testing_pipe( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def MAP_dispersions_filtering_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + 
design_factors: str | list[str] = "stage", + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, +): + """Perform a unit test for MAP dispersions filtering. + + Starting with the same genewise and MAP dispersions as the reference dataset, + filter dispersion outliers and compare the results with the reference. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + """ + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=None, + ) + + reference_data_path = ( + local_processed_data_path / "centers_data" / "tcga" / experiment_id + ) + # Get FL results. + fl_results = run_tcga_testing_pipe( + DispersionsFilteringTester( + design_factors=design_factors, + ref_levels=ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=raw_data_path, + processed_data_path=local_processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + local_processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # Tests that the final dispersions are equal to the pooled ones + assert np.allclose( + pooled_dds.varm["dispersions"], fl_results["dispersions"], equal_nan=True + ) + + +class DispersionsFilteringTester( + UnitTester, + LocFilterMAPDispersions, + AggPassOnResults, +): + """A class to implement a unit test for MAP dispersions filtering. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. 
``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + num_jobs : int + The number of jobs to use for local parallel processing in MLE tasks. + (default: ``8``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + Methods + ------- + build_compute_plan + Build the computation graph to perform and save MAP filtering. + + init_local_states + Remote_data method. + Copy the reference dds to the local state and add the non_zero mask. + It returns a shared state with the MAP dispersions. + + get_filtered_dispersions + Remote_data method. + Get the filtered dispersions from the local adata. + + save_filtered_dispersions + Save the filtered dispersions. + + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + num_jobs=8, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + num_jobs=num_jobs, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ) + + #### Define hyper parameters #### + self.num_jobs = num_jobs + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to perform and save MAP filtering. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). 
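+
+        Notes
+        -----
+        "Filtering" follows the DESeq2 convention (a sketch, not the exact code
+        path): genes whose genewise dispersion lies more than two log-residual
+        standard deviations above the fitted trend keep their genewise estimate,
+        while all other genes take the MAP value, roughly
+        ``np.where(outlier, genewise_dispersions, MAP_dispersions)``.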
+ """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Filter MAP dispersions #### + + local_states, _, round_idx = local_step( + local_method=self.filter_outlier_genes, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=shared_states[0], + aggregation_id=aggregation_node.organization_id, + description="Filter MAP dispersions.", + round_idx=round_idx, + clean_models=clean_models, + ) + + self.save_filtered_dispersions( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Copy the reference dds to the local state and add the non_zero mask. + + It returns a shared state with the MAP dispersions. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Shared state with "fitted_dispersions" and "prior_disp_var" keys. + """ + + self.local_adata = self.local_reference_dds.copy() + # Delete the unused "_mu_hat" layer as it will raise errors + del self.local_adata.layers["_mu_hat"] + self.local_adata.varm["non_zero"] = self.local_reference_dds.varm["non_zero"] + + return { + "MAP_dispersions": self.local_adata.varm["MAP_dispersions"], + } + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_filtered_dispersions(self, data_from_opener, shared_state): + """Get the filtered dispersions. + + Returns + ------- + dict + A dictionary with the filtered dispersions. + """ + return {"dispersions": self.local_adata.varm["dispersions"]} + + def save_filtered_dispersions( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """Save the filtered dispersions. + + This method gets the filtered dispersions from one of the local states and saves + them in the results format. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of train data nodes. + + aggregation_node : AggregationNode + Aggregation node. + + local_states : dict + Dictionary of local states. + + round_idx : int + Round index. + + clean_models : bool + Whether to clean the models after the computation. + Note that the last step is not cleaned. 
+ + """ + # ---- Get the filtered dispersions ---- # + local_states, shared_states, round_idx = local_step( + local_method=self.get_filtered_dispersions, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Get filtered dispersions.", + round_idx=round_idx, + clean_models=clean_models, + ) + + # ---- Save the filtered dispersions ---- # + + aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Save filtered dispersions.", + round_idx=round_idx, + clean_models=False, + ) + + return local_states, round_idx diff --git a/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_trend_curve.py b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_trend_curve.py new file mode 100644 index 0000000..9bd77e6 --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_lfc_dispersions/test_trend_curve.py @@ -0,0 +1,417 @@ +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +import numpy as np +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_dispersion_prior import ( # noqa: E501 + ComputeDispersionPrior, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.pass_on_results import AggPassOnResults +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels +from tests.unit_tests.unit_test_helpers.pass_on_first_shared_state import ( + AggPassOnFirstSharedState, +) +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), # TODO this case fails to converge in PyDESeq2 + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +def test_trend_curve( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + design_factors, + continuous_factors, +): + """Perform a unit test for the trend curve. + + Starting with the same genewise dispersions as the reference DeseqDataSet, fit a + parametric trend curve, compute the prior dispersion and compare the results with + the reference. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. 
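+
+    Both the parametric fit (roughly ``trend(mu) = a1 / mu + a0`` in DESeq2's
+    convention) and the resulting ``prior_disp_var`` (the variance of
+    log-dispersion residuals around the curve, corrected for sampling noise) are
+    compared with a 2% relative tolerance; if the reference falls back to a mean
+    dispersion fit, only the function type and ``mean_disp`` are compared.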
+ """ + + trend_curve_testing_pipe( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + +def trend_curve_testing_pipe( + data_path, + processed_data_path, + assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, +): + """Perform a unit test for the trend curve. + + Starting with the same genewise dispersions as the reference DeseqDataSet, fit a + parametric trend curve, compute the prior dispersion and compare the results with + the reference. + + Parameters + ---------- + data_path: Path + The path to the root data. + + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + """ + complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = processed_data_path / "centers_data" / "tcga" / experiment_id + # Get FL results. 
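+    # (The federated results below are expected to expose "disp_function_type",
+    # "trend_coeffs" or "mean_disp", and "prior_disp_var", saved from the
+    # dispersion trend shared state.)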
+ fl_results = run_tcga_testing_pipe( + TrendCurveTester( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=data_path, + processed_data_path=processed_data_path, + assets_directory=assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + if pooled_dds.uns["disp_function_type"] == "mean": + assert fl_results["disp_function_type"] == "mean" + assert np.allclose( + fl_results["mean_disp"], + pooled_dds.uns["mean_disp"], + equal_nan=False, + rtol=0.02, + ) + + return + + assert np.allclose( + fl_results["trend_coeffs"], + pooled_dds.uns["trend_coeffs"], + equal_nan=False, + rtol=0.02, + ) + + # Test the dispersion prior + assert np.allclose( + fl_results["prior_disp_var"], + pooled_dds.uns["prior_disp_var"], + equal_nan=True, + rtol=0.02, + ) + + +class TrendCurveTester( + UnitTester, ComputeDispersionPrior, AggPassOnFirstSharedState, AggPassOnResults +): + """A class to implement a unit test for the trend curve. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + min_disp : float + Lower threshold for dispersion parameters. (default: ``1e-8``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple or None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + Methods + ------- + build_compute_plan + Build the computation graph to compute the trend curve. + + init_local_states + A remote_data method. + Copy the reference dds to the local state and add the non_zero mask. + Return the genewise dispersions as the shared state. + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + min_disp: float = 1e-8, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...]
| None = None, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + min_disp=min_disp, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ) + + #### Define hyper parameters #### + self.min_disp = min_disp + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to compute the trend curve. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Pass the first shared state as the genewise dispersions shared states #### + + genewise_dispersions_shared_state, round_idx = aggregation_step( + aggregation_method=self.pass_on_shared_state, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Pass the genewise dispersions shared state", + clean_models=clean_models, + ) + + #### Fit dispersion trends #### + ( + local_states, + dispersion_trend_shared_state, + round_idx, + ) = self.compute_dispersion_prior( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + genewise_dispersions_shared_state=genewise_dispersions_shared_state, + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Save shared state #### + + aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=[dispersion_trend_shared_state], + round_idx=round_idx, + description="Save the first shared state", + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Copy the reference dds to the local state and add the non_zero mask. + + Returns a shared state with the genewise dispersions. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + shared_state : Any + Shared state with a "genewise_dispersions" key. + + Returns + ------- + dict + A dictionary containing the genewise dispersions. 
+ + """ + + self.local_adata = self.local_reference_dds.copy() + # Delete the unused "_mu_hat" layer as it will raise errors + del self.local_adata.layers["_mu_hat"] + self.local_adata.varm["non_zero"] = self.local_reference_dds.varm["non_zero"] + self.local_adata.uns["n_params"] = self.local_adata.obsm["design_matrix"].shape[ + -1 + ] + + return {"genewise_dispersions": self.local_adata.varm["genewise_dispersions"]} diff --git a/tests/unit_tests/deseq2_core/deseq2_stats/__init__.py b/tests/unit_tests/deseq2_core/deseq2_stats/__init__.py new file mode 100644 index 0000000..f162acc --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_stats/__init__.py @@ -0,0 +1 @@ +"""Module to test the deseq2 stats module.""" diff --git a/tests/unit_tests/deseq2_core/deseq2_stats/test_compute_padj.py b/tests/unit_tests/deseq2_core/deseq2_stats/test_compute_padj.py new file mode 100644 index 0000000..f6850fb --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_stats/test_compute_padj.py @@ -0,0 +1,552 @@ +import pickle as pkl +from pathlib import Path +from typing import Literal + +import numpy as np +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from pydeseq2.ds import DeseqStats +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_stats.compute_padj import ( + ComputeAdjustedPValues, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.pass_on_results import AggPassOnResults +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.parametrize( + "design_factors, continuous_factors, independent_filter", + [ + ("stage", None, True), + (["stage", "gender"], None, True), + (["stage", "gender", "CPE"], ["CPE"], True), + ("stage", None, False), + (["stage", "gender"], None, False), + (["stage", "gender", "CPE"], ["CPE"], False), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_compute_adj_small_genes( + design_factors, + continuous_factors, + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + independent_filter: bool, +): + compute_padj_testin_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + independent_filter=independent_filter, + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize("independent_filter", [True, False]) +def test_compute_adj_small_samples( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, 
+ independent_filter: bool, +): + compute_padj_testin_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + independent_filter=independent_filter, + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize("independent_filter", [True, False]) +def test_compute_adj_on_self_hosted( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + independent_filter: bool, +): + compute_padj_testin_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + independent_filter=independent_filter, + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.local +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize("independent_filter", [True, False]) +def test_compute_adj_on_local( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + independent_filter: bool, +): + compute_padj_testin_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + independent_filter=independent_filter, + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def compute_padj_testin_pipe( + raw_data_path, + processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + independent_filter=True, + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, +): + """Perform a unit test for Wald tests + Starting with the dispersions and LFC as the reference DeseqDataSet, perform Wald + tests and compare the results with the reference. + Parameters + ---------- + raw_data_path: Path + The path to the root data. + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + independent_filter: bool + Whether to use independent filtering to correct the p-values trend. + (default: ``True``). + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + backend: str + The backend to use. Either "subprocess" or "docker". 
+ only_two_centers: bool + If true, restrict the data to two centers. + design_factors: str or list + The design factors to use. + continuous_factors: list or None + The continuous factors to use. + ref_levels: dict or None + The reference levels of the design factors. + reference_dds_ref_level: tuple or None + The reference level of the design factors. + """ + complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = processed_data_path / "centers_data" / "tcga" / experiment_id + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_dir = processed_data_path / "pooled_data" / "tcga" / experiment_id + + pooled_dds_file_path = pooled_dds_file_dir / f"{pooled_dds_file_name}.pkl" + + # Get FL results. + fl_results = run_tcga_testing_pipe( + ComputeAdjustedPValuesTester( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + independent_filter=independent_filter, + reference_pooled_data_path=pooled_dds_file_dir, + ), + raw_data_path=raw_data_path, + processed_data_path=processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # Avoid outliers not refit warning + pooled_dds.refit_cooks = False + + # Run pydeseq2 Wald tests on the reference data + pooled_ds = DeseqStats( + pooled_dds, cooks_filter=False, independent_filter=independent_filter + ) + pooled_ds.summary() + + assert np.allclose( + fl_results["padj"], + pooled_ds.padj, + equal_nan=True, + rtol=0.02, + ) + + +class ComputeAdjustedPValuesTester( + UnitTester, ComputeAdjustedPValues, AggPassOnResults +): + """A class to implement a for p-value adjustment. + + # TODO merge the method for running the Wald test on the reference data with the + # TODO equivalent method for testing cooks. + + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). 
+ contrast : list or None + A list of three strings, in the following format: + ``['variable_of_interest', 'tested_level', 'ref_level']``. + Names must correspond to the metadata data passed to the DeseqDataSet. + E.g., ``['condition', 'B', 'A']`` will measure the LFC of 'condition B' compared + to 'condition A'. + For continuous variables, the last two strings should be left empty, e.g. + ``['measurement', '', ''].`` + If None, the last variable from the design matrix is chosen + as the variable of interest, and the reference level is picked alphabetically. + (default: ``None``). + independent_filter : bool + Whether to use independent filtering to correct the p-values trend. + (default: ``True``). + alpha : float + P-value and adjusted p-value significance threshold (usually 0.05). + (default: ``0.05``). + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + reference_pooled_data_path : str or Path + The path to the reference pooled data. This is used to build the reference + DeseqStats object. This is only used for testing purposes, and should not be + used in a real-world scenario. + reference_dds_ref_level : tuple or None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + contrast: list[str] | None = None, + independent_filter: bool = True, + alpha: float = 0.05, + lfc_null: float = 0.0, + alt_hypothesis: Literal["greaterAbs", "lessAbs", "greater", "less"] + | None = None, + joblib_backend: str = "loky", + irls_batch_size: int = 100, + num_jobs: int = 8, + joblib_verbosity: int = 3, + reference_data_path: str | Path | None = None, + reference_pooled_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + independent_filter=independent_filter, + alpha=alpha, + contrast=contrast, + lfc_null=lfc_null, + alt_hypothesis=alt_hypothesis, + joblib_backend=joblib_backend, + irls_batch_size=irls_batch_size, + num_jobs=num_jobs, + joblib_verbosity=joblib_verbosity, + reference_data_path=reference_data_path, + reference_pooled_data_path=reference_pooled_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ) + + self.joblib_verbosity = joblib_verbosity + self.num_jobs = num_jobs + self.joblib_backend = joblib_backend + self.irls_batch_size = irls_batch_size + + self.lfc_null = lfc_null + self.alt_hypothesis = alt_hypothesis + + self.reference_pooled_data_path = reference_pooled_data_path + self.reference_dds_ref_level = reference_dds_ref_level + + self.independent_filter = independent_filter + self.alpha = alpha + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to test the computation of adjusted p-values. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes.
+ aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, empty_shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Compute the reference wald test results #### + wald_test_shared_state, round_idx = aggregation_step( + aggregation_method=self.agg_run_wald_test_on_ground_truth, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=empty_shared_states, + description="Run Wald test on ground truth", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Compute the adusted p-values #### + ( + local_states, + round_idx, + ) = self.compute_adjusted_p_values( + train_data_nodes, + aggregation_node, + local_states, + wald_test_shared_state, + round_idx, + clean_models=clean_models, + ) + + local_states, wald_test_shared_states, round_idx = local_step( + local_method=self.get_results_from_local_adata, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Get results to share from the local centers", + round_idx=round_idx, + clean_models=clean_models, + ) + + aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=wald_test_shared_states, + description="Save the first shared state.", + round_idx=round_idx, + clean_models=False, + ) + + @remote + @log_remote + def agg_run_wald_test_on_ground_truth(self, shared_states: dict) -> dict: + """Run Wald tests on the reference data. + Parameters + ---------- + shared_states : dict + Shared states. Not used. + Returns + ------- + shared_states : dict + Shared states. The new shared state contains the Wald test results on the + pooled reference. + """ + pooled_dds_file_name = get_ground_truth_dds_name(self.reference_dds_ref_level) + + pooled_dds_file_path = ( + Path(self.reference_pooled_data_path) / f"{pooled_dds_file_name}.pkl" + ) + + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # Avoid outliers not refit warning + pooled_dds.refit_cooks = False # TODO to change after refit cooks implemented. 
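+        # (Illustrative sketch, not the call chain used below: per gene the Wald
+        # statistic is approximately lfc_contrast / wald_se, with a two-sided
+        # p-value p = 2 * scipy.stats.norm.sf(abs(statistic)); DeseqStats
+        # computes these through run_wald_test() and they serve as ground truth.)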
+ + # Run pydeseq2 Wald tests on the reference data + pooled_ds = DeseqStats(pooled_dds, cooks_filter=False, independent_filter=False) + pooled_ds.run_wald_test() + + return { + "p_values": pooled_ds.p_values.to_numpy(), + "wald_statistics": pooled_ds.statistics, + "wald_se": pooled_ds.SE, + } + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_results_from_local_adata( + self, + data_from_opener, + shared_state: dict | None, + ) -> dict: + """ + Get the results to share from the local states. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict, optional + Not used. + + Returns + ------- + dict + Shared state containing the adjusted p-values, the p-values, the Wald + standard errors, and the Wald statistics. + + """ + + shared_state = { + varm_key: self.local_adata.varm[varm_key] + for varm_key in ["padj", "p_values", "wald_se", "wald_statistics"] + } + return shared_state diff --git a/tests/unit_tests/deseq2_core/deseq2_stats/test_cooks_filtering.py b/tests/unit_tests/deseq2_core/deseq2_stats/test_cooks_filtering.py new file mode 100644 index 0000000..9e87459 --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_stats/test_cooks_filtering.py @@ -0,0 +1,544 @@ +import pickle as pkl +from pathlib import Path +from typing import Any +from typing import Literal + +import anndata as ad +import numpy as np +import pandas as pd +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from pydeseq2.ds import DeseqStats +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_stats.cooks_filtering import CooksFiltering +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.pass_on_results import AggPassOnResults +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "cooks_filter, refit_cooks, design_factors, continuous_factors", + [ + (True, False, "stage", None), + (True, True, "stage", None), + (False, False, "stage", None), + (False, True, "stage", None), + (True, False, ["stage", "gender"], None), + (True, False, ["stage", "gender", "CPE"], ["CPE"]), + (True, True, ["stage", "gender", "CPE"], ["CPE"]), + (False, False, ["stage", "gender", "CPE"], ["CPE"]), + (False, True, ["stage", "gender", "CPE"], ["CPE"]), + ], +) +def test_cooks_filtering_small_genes( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + cooks_filter: bool, + refit_cooks: bool, + design_factors: str | list[str], + continuous_factors: list[str] | None, +): + cooks_filtering_testing_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + cooks_filter=cooks_filter, + 
refit_cooks=refit_cooks, + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize("cooks_filter", [True, False]) +def test_cooks_filtering_on_self_hosted( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + cooks_filter: bool, +): + cooks_filtering_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + cooks_filter=cooks_filter, + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.local +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize("cooks_filter", [True, False]) +def test_cooks_filtering_on_local( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + cooks_filter: bool, +): + cooks_filtering_testing_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + cooks_filter=cooks_filter, + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def cooks_filtering_testing_pipe( # TODO we will have to add a case when cooks + # TODO are refitted + raw_data_path, + processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + cooks_filter: bool = True, + refit_cooks: bool = False, + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, +): + """Perform a unit test for Wald tests + Starting with the dispersions and LFC as the reference DeseqDataSet, perform Wald + tests and compare the results with the reference. + Parameters + ---------- + raw_data_path: Path + The path to the root data. + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + cooks_filter: bool + Whether to filter Cook's distances. (default: ``True``). + refit_cooks: bool + Whether to refit Cook's outliers. (default: ``False``). + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + backend: str + The backend to use. Either "subprocess" or "docker". + only_two_centers: bool + If true, restrict the data to two centers. 
+ design_factors: str or list + The design factors to use. + continuous_factors: list or None + The continuous factors to use. + ref_levels: dict or None + The reference levels of the design factors. + reference_dds_ref_level: tuple or None + The reference level of the design factors. + """ + complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = processed_data_path / "centers_data" / "tcga" / experiment_id + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name( + reference_dds_ref_level, refit_cooks=refit_cooks + ) + pooled_dds_file_dir = processed_data_path / "pooled_data" / "tcga" / experiment_id + + pooled_dds_file_path = pooled_dds_file_dir / f"{pooled_dds_file_name}.pkl" + + # Get FL results. + fl_results = run_tcga_testing_pipe( + CooksFilteringTester( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=complete_ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + cooks_filter=cooks_filter, + refit_cooks=refit_cooks, + reference_pooled_data_path=pooled_dds_file_dir, + ), + raw_data_path=raw_data_path, + processed_data_path=processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + refit_cooks=refit_cooks, + ) + + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # Avoid outliers not refit warning + if not refit_cooks: + pooled_dds.refit_cooks = False + + # Run pydeseq2 Wald tests on the reference data + pooled_ds = DeseqStats( + pooled_dds, cooks_filter=cooks_filter, independent_filter=False + ) + pooled_ds.run_wald_test() + if cooks_filter: + pooled_ds._cooks_filtering() + + assert np.allclose( + fl_results["p_values"], + pooled_ds.p_values, + equal_nan=True, + rtol=0.02, + ) + + +class CooksFilteringTester(UnitTester, CooksFiltering, AggPassOnResults): + """A class to implement a unit test for Wald tests. + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + contrast : list or None + A list of three strings, in the following format: + ``['variable_of_interest', 'tested_level', 'ref_level']``. 
+ Names must correspond to the metadata data passed to the DeseqDataSet. + E.g., ``['condition', 'B', 'A']`` will measure the LFC of 'condition B' compared + to 'condition A'. + For continuous variables, the last two strings should be left empty, e.g. + ``['measurement', '', ''].`` + If None, the last variable from the design matrix is chosen + as the variable of interest, and the reference level is picked alphabetically. + (default: ``None``). + cooks_filter : bool + Whether to filter Cook's distances. (default: ``True``). + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + reference_pooled_data_path : str or Path + The path to the reference pooled data. This is used to build the reference + DeseqStats object. This is only used for testing purposes, and should not be + used in a real-world scenario. + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + min_mu : float + The minimum value of mu. (default: ``0.5``). Needed to compute the + Cook's distances. + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + contrast: list[str] | None = None, + cooks_filter: bool = True, + refit_cooks: bool = False, + lfc_null: float = 0.0, + alt_hypothesis: Literal["greaterAbs", "lessAbs", "greater", "less"] + | None = None, + joblib_backend: str = "loky", + irls_batch_size: int = 100, + num_jobs: int = 8, + joblib_verbosity: int = 3, + reference_data_path: str | Path | None = None, + reference_pooled_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + min_mu: float = 0.5, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + cooks_filter=cooks_filter, + refit_cooks=refit_cooks, + contrast=contrast, + lfc_null=lfc_null, + alt_hypothesis=alt_hypothesis, + joblib_backend=joblib_backend, + irls_batch_size=irls_batch_size, + num_jobs=num_jobs, + joblib_verbosity=joblib_verbosity, + reference_data_path=reference_data_path, + reference_pooled_data_path=reference_pooled_data_path, + reference_dds_ref_level=reference_dds_ref_level, + min_mu=min_mu, + ) + + self.joblib_verbosity = joblib_verbosity + self.num_jobs = num_jobs + self.joblib_backend = joblib_backend + self.irls_batch_size = irls_batch_size + + self.lfc_null = lfc_null + self.alt_hypothesis = alt_hypothesis + + self.reference_pooled_data_path = reference_pooled_data_path + self.reference_dds_ref_level = reference_dds_ref_level + + self.cooks_filter = cooks_filter + self.refit_cooks = refit_cooks + + self.min_mu = min_mu + self.layers_to_save_on_disk = {"local_adata": ["cooks"], "refit_adata": None} + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a DESeq2 pipe. + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. 
+ evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, empty_shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Compute the reference wald test results #### + wald_test_shared_state, round_idx = aggregation_step( + aggregation_method=self.agg_run_wald_test_on_ground_truth, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=empty_shared_states, + description="Run Wald test on ground truth", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Compute the cooks dispersions + if self.cooks_filter: + local_states, wald_test_shared_state, round_idx = self.cooks_filtering( + train_data_nodes, + aggregation_node, + local_states, + wald_test_shared_state, + round_idx, + clean_models=clean_models, + ) + + #### Save the first shared state #### + + aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=[wald_test_shared_state], + description="Save the first shared state", + round_idx=round_idx, + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Copy the reference dds to the local state and add the non_zero mask. + + Also sets the total number of samples in the uns attribute. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + shared_state : Any + Shared state. Not used. 
+ """ + pooled_dds_file_name = get_ground_truth_dds_name( + self.reference_dds_ref_level, refit_cooks=self.refit_cooks + ) + + pooled_dds_file_path = ( + Path(self.reference_pooled_data_path) / f"{pooled_dds_file_name}.pkl" + ) + + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + counts_by_lvl = pooled_dds.obsm["design_matrix"].value_counts() + self.local_adata = self.local_reference_dds.copy() + # Delete the unused "_mu_hat" layer as it will raise errors + del self.local_adata.layers["_mu_hat"] + if "replace_cooks" in self.local_adata.layers.keys(): + del self.local_adata.layers["replace_cooks"] + + self.local_adata.varm["non_zero"] = self.local_reference_dds.varm["non_zero"] + self.local_adata.uns["tot_num_samples"] = pooled_dds.n_obs + self.local_adata.uns["num_replicates"] = pd.Series(counts_by_lvl.values) + self.local_adata.obs["cells"] = [ + np.argwhere(counts_by_lvl.index == tuple(design))[0, 0] + for design in self.local_adata.obsm["design_matrix"].values + ] + self.local_adata.uns["n_params"] = self.local_adata.obsm["design_matrix"].shape[ + 1 + ] + return {} + + @remote + @log_remote + def agg_run_wald_test_on_ground_truth(self, shared_states: dict) -> dict: + """Run Wald tests on the reference data. + + Parameters + ---------- + shared_states : dict + Shared states. Not used. + + Returns + ------- + shared_states : dict + Shared states. The new shared state contains the Wald test results on the + pooled reference. + + """ + pooled_dds_file_name = get_ground_truth_dds_name( + self.reference_dds_ref_level, refit_cooks=self.refit_cooks + ) + + pooled_dds_file_path = ( + Path(self.reference_pooled_data_path) / f"{pooled_dds_file_name}.pkl" + ) + + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # Avoid outliers not refit warning + if not self.refit_cooks: + pooled_dds.refit_cooks = False + + # Run pydeseq2 Wald tests on the reference data + pooled_ds = DeseqStats(pooled_dds, cooks_filter=False, independent_filter=False) + pooled_ds.run_wald_test() + + return { + "p_values": pooled_ds.p_values, + "wald_statistics": pooled_ds.statistics, + "wald_se": pooled_ds.SE, + } diff --git a/tests/unit_tests/deseq2_core/deseq2_stats/test_wald_tests.py b/tests/unit_tests/deseq2_core/deseq2_stats/test_wald_tests.py new file mode 100644 index 0000000..f3c71af --- /dev/null +++ b/tests/unit_tests/deseq2_core/deseq2_stats/test_wald_tests.py @@ -0,0 +1,535 @@ +import pickle as pkl +from pathlib import Path +from typing import Any +from typing import Literal + +import anndata as ad +import numpy as np +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from pydeseq2.ds import DeseqStats +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_stats.wald_tests import RunWaldTests +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import build_contrast +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.pass_on_results import AggPassOnResults +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels +from 
tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.parametrize( + "design_factors, continuous_factors, contrast", + [ + ("stage", None, None), + (["stage", "gender"], None, ["stage", "Advanced", "Non-advanced"]), + (["stage", "gender", "CPE"], ["CPE"], ["CPE", "", ""]), + (["stage", "gender", "CPE"], ["CPE"], ["gender", "female", "male"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_wald_tests_contrasts_on_small_genes( + design_factors, + continuous_factors, + contrast, + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + wald_tests_testing_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + contrast=contrast, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors, alt_hypothesis", + [ + ("stage", None, "greaterAbs"), + (["stage", "gender"], None, "lessAbs"), + (["stage", "gender", "CPE"], ["CPE"], "greater"), + (["stage", "gender", "CPE"], ["CPE"], "less"), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_wald_tests_alt_on_small_genes( + design_factors, + continuous_factors, + alt_hypothesis, + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + wald_tests_testing_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + alt_hypothesis=alt_hypothesis, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_wald_tests_on_small_samples_on_self_hosted_fast( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + wald_tests_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors, contrast", + [ + ("stage", None, None), + (["stage", "gender"], None, ["stage", "Advanced", "Non-advanced"]), + (["stage", "gender", "CPE"], ["CPE"], ["CPE", "", ""]), + ], +) +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_wald_tests_on_self_hosted_slow( + design_factors, + continuous_factors, + contrast, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + wald_tests_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + 
only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + contrast=contrast, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def wald_tests_testing_pipe( + raw_data_path, + processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + contrast: list[str] | None = None, + lfc_null: float = 0.0, + alt_hypothesis: Literal["greaterAbs", "lessAbs", "greater", "less"] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, +): + """Perform a unit test for Wald tests + + Starting with the dispersions and LFC as the reference DeseqDataSet, perform Wald + tests and compare the results with the reference. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. + + contrast: list or None + The contrast to use. + + lfc_null: float + The null hypothesis for the LFC. + + alt_hypothesis: str or None + The alternative hypothesis. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + """ + complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = processed_data_path / "centers_data" / "tcga" / experiment_id + # Get FL results. 
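+    # The tester below runs the federated pipeline end to end (simulation or
+    # subprocess backend); ``AggPassOnResults`` saves the final shared state, so
+    # ``fl_results`` is expected to expose the "p_values", "wald_statistics" and
+    # "wald_se" arrays that are compared against the pooled reference further down.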
+ fl_results = run_tcga_testing_pipe( + WaldTestTester( + design_factors=design_factors, + continuous_factors=continuous_factors, + contrast=contrast, + lfc_null=lfc_null, + alt_hypothesis=alt_hypothesis, + ref_levels=ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=raw_data_path, + processed_data_path=processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # Avoid outliers not refit warning + pooled_dds.refit_cooks = False + + # Run pydeseq2 Wald tests on the reference data + pooled_ds = DeseqStats( + pooled_dds, + cooks_filter=False, + independent_filter=False, + contrast=contrast, + lfc_null=lfc_null, + alt_hypothesis=alt_hypothesis, + ) + pooled_ds.run_wald_test() + + # Show max of the absolute difference between the results + assert np.allclose( + fl_results["p_values"], + pooled_ds.p_values, + equal_nan=True, + rtol=0.02, + ) + + # Test the dispersion prior + assert np.allclose( + fl_results["wald_statistics"], + pooled_ds.statistics, + equal_nan=True, + rtol=0.02, + ) + + assert np.allclose( + fl_results["wald_se"], + pooled_ds.SE, + equal_nan=True, + rtol=0.02, + ) + + +class WaldTestTester(UnitTester, RunWaldTests, AggPassOnResults): + """A class to implement a unit test for Wald tests. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + contrast : list or None + A list of three strings, in the following format: + ``['variable_of_interest', 'tested_level', 'ref_level']``. + Names must correspond to the metadata data passed to the DeseqDataSet. + E.g., ``['condition', 'B', 'A']`` will measure the LFC of 'condition B' compared + to 'condition A'. + For continuous variables, the last two strings should be left empty, e.g. + ``['measurement', '', ''].`` + If None, the last variable from the design matrix is chosen + as the variable of interest, and the reference level is picked alphabetically. + (default: ``None``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. 
+ + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + Methods + ------- + init_local_states + A remote_data method to copy the reference dds to the local state, add the + non_zero mask, and set the contrast in the uns attribute. + + build_compute_plan + Build the computation graph to test Wald test computations. + + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + contrast: list[str] | None = None, + lfc_null: float = 0.0, + alt_hypothesis: Literal["greaterAbs", "lessAbs", "greater", "less"] + | None = None, + joblib_backend: str = "loky", + irls_batch_size: int = 100, + num_jobs: int = 8, + joblib_verbosity: int = 3, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + contrast=contrast, + lfc_null=lfc_null, + alt_hypothesis=alt_hypothesis, + joblib_backend=joblib_backend, + irls_batch_size=irls_batch_size, + num_jobs=num_jobs, + joblib_verbosity=joblib_verbosity, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ) + + self.joblib_verbosity = joblib_verbosity + self.num_jobs = num_jobs + self.joblib_backend = joblib_backend + self.irls_batch_size = irls_batch_size + + self.lfc_null = lfc_null + self.alt_hypothesis = alt_hypothesis + + self.layers_to_save_on_disk = {"local_adata": ["cooks"], "refit_adata": None} + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to test Wald test computations. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). 
+ """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, _, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Run Wald tests #### + local_states, wald_shared_state, round_idx = self.run_wald_tests( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=[wald_shared_state], + description="Save first shared state", + round_idx=round_idx, + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states(self, data_from_opener: ad.AnnData, shared_state: Any): + """Copy the reference dds to the local state and add the non_zero mask. + + Set the contrast in the uns attribute. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + shared_state : Any + Shared state with a "genewise_dispersions" key. + + """ + + self.local_adata = self.local_reference_dds.copy() + # Delete the unused "_mu_hat" layer as it will raise errors + del self.local_adata.layers["_mu_hat"] + + self.local_adata.varm["non_zero"] = self.local_reference_dds.varm["non_zero"] + self.local_adata.uns["contrast"] = build_contrast( + design_factors=self.design_factors, + design_columns=self.local_adata.obsm["design_matrix"].columns, + continuous_factors=self.continuous_factors, + contrast=self.contrast, + ) diff --git a/tests/unit_tests/deseq2_core/test_cooks_distances.py b/tests/unit_tests/deseq2_core/test_cooks_distances.py new file mode 100644 index 0000000..d1e7f09 --- /dev/null +++ b/tests/unit_tests/deseq2_core/test_cooks_distances.py @@ -0,0 +1,538 @@ +import pickle as pkl +from pathlib import Path + +import pandas as pd +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.compute_cook_distance import ComputeCookDistances +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.get_num_replicates import ( # noqa: E501 + GetNumReplicates, +) +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels +from 
tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_cooks_distances_on_small_genes( + design_factors, + continuous_factors, + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + """ + Test Cook's distances on a small number of genes. + + Note that the first subcase is particularly important, as it tests the case where + the number of replicates for all levels of the design is greater than 3. + """ + cooks_distances_testing_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_cooks_distances_on_small_samples_on_self_hosted_fast( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Test Cook's distances on a small number of samples on a self hosted runner. + + This test is quite important for the (["stage", "gender", "CPE"], ["CPE"]) + subset, as it tests the case where the number of replicates is less than 3 + for all levels of the design, and we enter in this specific case of the computation + of the trimmed variance. + """ + cooks_distances_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_cooks_distances_on_self_hosted_slow( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """ + Test Cook's distances on a self hosted runner. + + This test is particularly important for the (["stage", "gender", "CPE"], ["CPE"]) + subcase, as it tests the case where some levels of the design have less than + 3 replicates while other have more. This means that only part of the levels are + taken into account into the computation of the trimmed variance. 
+ + """ + cooks_distances_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def cooks_distances_testing_pipe( + raw_data_path, + processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, +): + """Perform a unit test for Cook's distances. + + Starting with the same counts as the reference dataset, compute sCook's distances + and compare the results with the reference. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + """ + complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # Setup the ground truth path. 
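+    # ``experiment_id`` identifies the processed-data layout for this dataset and
+    # design configuration; it is used below to locate both the per-center data
+    # (under ``centers_data``) and the pooled reference data (under ``pooled_data``).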
+ experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = processed_data_path / "centers_data" / "tcga" / experiment_id + + fl_results = run_tcga_testing_pipe( + CooksDistanceTester( + design_factors=design_factors, + ref_levels=complete_ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=raw_data_path, + processed_data_path=processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + df_cooks_distance = fl_results["cooks_distance_df"] + pooled_dds_cooks_df = pooled_dds.to_df("cooks") + + df_cooks_distance = df_cooks_distance.loc[pooled_dds_cooks_df.index] + + pd.testing.assert_frame_equal(pooled_dds_cooks_df, df_cooks_distance) + + +class CooksDistanceTester( + UnitTester, + ComputeCookDistances, + GetNumReplicates, +): + """A class to implement a unit test for Cook's distances. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + min_mu : float + The minimum value of the mean expression to be considered as a valid gene. + Used to compute the hat diagonals matrix, which is a required input for the + computation of Cook's distances. (default: 0.5). + + trimmed_mean_num_iter: int + The number of iterations to use when computing the trimmed mean + in a federated way, i.e. the number of dichotomy steps. + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] 
| None = None, + min_mu: float = 0.5, + trimmed_mean_num_iter: int = 40, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + min_mu=min_mu, + trimmed_mean_num_iter=trimmed_mean_num_iter, + ) + + self.min_mu = 0.5 + self.trimmed_mean_num_iter = trimmed_mean_num_iter + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a FedDESeq2 pipe. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Compute the number of replicates #### + + local_states, round_idx = self.get_num_replicates( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + ( + local_states, + dispersion_for_cook_shared_state, + round_idx, + ) = self.compute_cook_distance( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ) + + self.save_cook_distance( + train_data_nodes, + aggregation_node, + local_states, + dispersion_for_cook_shared_state, + round_idx, + clean_models=clean_models, + ) + + def save_cook_distance( + self, + train_data_nodes, + aggregation_node, + local_states, + dispersion_for_cook_shared_state, + round_idx, + clean_models, + ): + """ + Save Cook's distances. It must be used in the main pipeline while testing. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + aggregation_node : AggregationNode + The aggregation node. + local_states : dict + Local states. Required to propagate intermediate results. + dispersion_for_cook_shared_state : dict + Shared state with the dispersion values for Cook's distances, in a + "cooks_dispersions" key. + round_idx : int + Index of the current round. + clean_models : bool + Whether to clean the models after the computation. + + Returns + ------- + local_states : dict + Local states. Required to propagate intermediate results. + shared_states : dict + Shared states. Required to propagate intermediate results. + round_idx : int + The updated round index. 
+ + """ + local_states, shared_states, round_idx = local_step( + local_method=self.get_loc_cook_distance, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=dispersion_for_cook_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Get cook distances", + round_idx=round_idx, + clean_models=clean_models, + ) + + aggregation_step( + aggregation_method=self.agg_cook_distance, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Aggregate cook distances", + round_idx=round_idx, + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_loc_cook_distance(self, data_from_opener, shared_state: dict) -> dict: + """ + Save Cook's distances. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Not used. + + Returns + ------- + dict + Dictionary with the following key: + - cooks_distance_df: Cook's distances in a df + + """ + return {"cooks_distance_df": self.local_adata.to_df("cooks")} + + @remote + @log_remote + def agg_cook_distance(self, shared_states: list[dict]): + """ + Aggregate Cook's distances. + + Parameters + ---------- + shared_states : list[dict] + List of shared states with the following key: + - cooks_distance_df: Cook's distances in a df + + """ + cooks_distance_df = pd.concat( + [shared_state["cooks_distance_df"] for shared_state in shared_states], + axis=0, + ) + self.results = {"cooks_distance_df": cooks_distance_df} diff --git a/tests/unit_tests/deseq2_core/test_design_matrices.py b/tests/unit_tests/deseq2_core/test_design_matrices.py new file mode 100644 index 0000000..5e5e634 --- /dev/null +++ b/tests/unit_tests/deseq2_core/test_design_matrices.py @@ -0,0 +1,514 @@ +import pickle as pkl +from pathlib import Path + +import pandas as pd +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from pandas.testing import assert_frame_equal +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.build_design_matrix import BuildDesignMatrix +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_build_design_on_small_genes( + design_factors, + continuous_factors, + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + build_design_matrix_testing_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + 
dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_build_design_on_small_samples_on_self_hosted_fast( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + build_design_matrix_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_build_design_on_self_hosted_slow( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + build_design_matrix_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.local +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_build_design_on_local( + design_factors, + continuous_factors, + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + build_design_matrix_testing_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=("stage", "Advanced"), + continuous_factors=continuous_factors, + ) + + +def build_design_matrix_testing_pipe( + raw_data_path, + processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] 
| None = ("stage", "Advanced"),
+):
+    """Perform a unit test for the design matrix.
+
+    Starting with the same counts and metadata as the reference DeseqDataSet, build
+    the design matrix in a federated way and compare it with the reference design
+    matrix.
+
+    Parameters
+    ----------
+    raw_data_path: Path
+        The path to the root data.
+    processed_data_path: Path
+        The path to the processed data. The subdirectories will
+        be created if needed.
+    tcga_assets_directory: Path
+        The path to the assets directory. It must contain the
+        opener.py file and the description.md file.
+    dataset_name: TCGADatasetNames
+        The name of the dataset, for example "TCGA-LUAD".
+    small_samples: bool
+        Whether to use a small number of samples.
+        If True, the number of samples is reduced to 10 per center.
+    small_genes: bool
+        Whether to use a small number of genes.
+        If True, the number of genes is reduced to 100.
+    simulate: bool
+        If true, use the simulation mode, otherwise use the subprocess mode.
+    backend: str
+        The backend to use. Either "subprocess" or "docker".
+    only_two_centers: bool
+        If true, restrict the data to two centers.
+    design_factors: str or list
+        The design factors to use.
+    continuous_factors: list[str] or None
+        The continuous factors among the design factors.
+    ref_levels: dict or None
+        The reference levels of the design factors.
+    reference_dds_ref_level: tuple or None
+        The reference level of the design factors.
+    """
+    complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels(
+        design_factors=design_factors,
+        continuous_factors=continuous_factors,
+        ref_levels=ref_levels,
+        reference_dds_ref_level=reference_dds_ref_level,
+    )
+
+    # Setup the ground truth path.
+    experiment_id = get_experiment_id(
+        dataset_name,
+        small_samples,
+        small_genes,
+        only_two_centers=only_two_centers,
+        design_factors=design_factors,
+        continuous_factors=continuous_factors,
+    )
+
+    reference_data_path = processed_data_path / "centers_data" / "tcga" / experiment_id
+
+    # pooled dds file name
+    pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level)
+
+    pooled_dds_file_dir = processed_data_path / "pooled_data" / "tcga" / experiment_id
+
+    pooled_dds_file_path = pooled_dds_file_dir / f"{pooled_dds_file_name}.pkl"
+
+    # Get FL results.
+    fl_results = run_tcga_testing_pipe(
+        DesignMatrixTester(
+            design_factors=design_factors,
+            continuous_factors=continuous_factors,
+            ref_levels=complete_ref_levels,
+            reference_data_path=reference_data_path,
+            reference_dds_ref_level=reference_dds_ref_level,
+            reference_pooled_data_path=pooled_dds_file_dir,
+        ),
+        raw_data_path=raw_data_path,
+        processed_data_path=processed_data_path,
+        assets_directory=tcga_assets_directory,
+        simulate=simulate,
+        dataset_name=dataset_name,
+        small_samples=small_samples,
+        small_genes=small_genes,
+        backend=backend,
+        only_two_centers=only_two_centers,
+        register_data=True,
+        design_factors=design_factors,
+        continuous_factors=continuous_factors,
+        reference_dds_ref_level=reference_dds_ref_level,
+    )
+
+    fl_design_matrix = fl_results["design_matrix"]
+
+    with open(pooled_dds_file_path, "rb") as file:
+        pooled_dds = pkl.load(file)
+
+    assert_frame_equal(
+        fl_design_matrix.sort_index().reindex(
+            pooled_dds.obsm["design_matrix"].columns, axis=1
+        ),
+        pooled_dds.obsm["design_matrix"].sort_index(),
+        check_dtype=False,
+    )
+
+
+class DesignMatrixTester(
+    UnitTester,
+    BuildDesignMatrix,
+):
+    """A class to implement a unit test for the design matrix construction.
+ + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + Methods + ------- + build_compute_plan + Build the computation graph to run the creation of the design matrix and + the means to save it. + + save_design_matrix + Save the design matrix computed using Substra. + + get_local_design_matrix + Get the local design matrix from the obsm of the AnnData. + + concatenate_design_matrices + Concatenate design matrices together for registration. + + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ) + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a FedDESeq2 pipe. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). 
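+
+        Notes
+        -----
+        The plan chains ``set_local_reference_dataset``, ``build_design_matrix``
+        and the ``save_design_matrix`` step defined below, in that order.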
+ + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Build design matrices #### + + local_states, _, round_idx = self.build_design_matrix( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + print("Finished building design matrices.") + + #### Check the design matrices #### + + _ = self.save_design_matrix( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ) + + def save_design_matrix( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """Check the design matrix. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: int + The current round + + clean_models: bool + Whether to clean the models after the computation. + + Returns + ------- + local_states: dict + The local states, containing what is needed for evaluation. + + round_idx: int + The updated round. + + """ + # ---- Concatenate local design matrices ----# + + local_states, shared_states, round_idx = local_step( + local_method=self.get_local_design_matrix, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + round_idx=round_idx, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Get the local design matrix", + clean_models=clean_models, + ) + + aggregation_step( + aggregation_method=self.concatenate_design_matrices, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Concatenating local design matrices", + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_local_design_matrix( + self, + data_from_opener, + shared_state, + ) -> dict: + """ + Get the local design matrix from the obsm of the AnnData. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict or None + Should be None. + + Returns + ------- + dict + A dictionary containing the local design matrix in the + "local_design_matrix" field. + """ + + return {"local_design_matrix": self.local_adata.obsm["design_matrix"]} + + @remote + @log_remote + def concatenate_design_matrices(self, shared_states): + """Concatenate design matrices together for registration. + + Parameters + ---------- + shared_states : list + List of design matrices from training nodes. 
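+
+        Notes
+        -----
+        Nothing is returned: the concatenated design matrix is stored in
+        ``self.results["design_matrix"]``, from which the test retrieves it.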
+ + """ + tot_design = pd.concat( + [state["local_design_matrix"] for state in shared_states] + ) + self.results = {"design_matrix": tot_design} diff --git a/tests/unit_tests/deseq2_core/test_refit_cooks.py b/tests/unit_tests/deseq2_core/test_refit_cooks.py new file mode 100644 index 0000000..62a94cc --- /dev/null +++ b/tests/unit_tests/deseq2_core/test_refit_cooks.py @@ -0,0 +1,757 @@ +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +import numpy as np +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions import DESeq2LFCDispersions +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.get_num_replicates import ( # noqa: E501 + GetNumReplicates, +) +from fedpydeseq2.core.deseq2_core.replace_outliers import ReplaceCooksOutliers +from fedpydeseq2.core.deseq2_core.replace_refitted_values import ReplaceRefittedValues +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_refit_cooks_on_small_genes( + design_factors, + continuous_factors, + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + refit_cooks_testing_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_refit_cooks_on_small_samples_on_self_hosted_fast( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + refit_cooks_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.self_hosted_slow 
+@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_refit_cooks_on_self_hosted_slow( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + refit_cooks_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def refit_cooks_testing_pipe( + raw_data_path, + processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, +): + """Perform a unit test for outlier imputation. + + Starting with the same counts as the reference dataset, replace count values of + Cooks outliers and compare with the reference. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + """ + complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # Setup the ground truth path. 
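+    # Note: as in the other unit tests, the per-center data is expected under
+    # <processed_data_path>/centers_data/tcga/<experiment_id>/ and the pooled
+    # reference DeseqDataSet pickle under
+    # <processed_data_path>/pooled_data/tcga/<experiment_id>/.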
+ experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = processed_data_path / "centers_data" / "tcga" / experiment_id + + fl_results = run_tcga_testing_pipe( + RefitOutliersTester( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=raw_data_path, + processed_data_path=processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + fl_adata = fl_results["local_adatas"][0] + + min_disp = fl_results["min_disp"] + max_disp = fl_adata.uns["max_disp"] + max_beta = fl_results["max_beta"] + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + init_pooled_dds = pooled_dds.copy() + pooled_dds.refit_cooks = True + pooled_dds.refit() + + # Check that the refitted size factors are the same as the original values + for adata in fl_results["local_adatas"]: + assert np.allclose( + adata.obsm["size_factors"], + init_pooled_dds[adata.obs_names].obsm["size_factors"], + equal_nan=True, + ) + + # Check that refitted genes are the same + assert np.array_equal( + fl_adata.varm["refitted"], pooled_dds.varm["refitted"], equal_nan=True + ) + + # Check that replaceable samples are the same + for adata in fl_results["local_adatas"]: + assert np.array_equal( + adata.obsm["replaceable"], + pooled_dds[adata.obs_names].obsm["replaceable"], + equal_nan=True, + ) + + if pooled_dds.varm["refitted"].sum() > 0: + # Check that the genewise dispersions of refitted genes have changed + # Except if out of bounds + fl_genewise_on_refitted = fl_adata.varm["genewise_dispersions"][ + fl_adata.varm["refitted"] + ] + init_pooled_genewise_on_refitted = init_pooled_dds.varm["genewise_dispersions"][ + pooled_dds.varm["refitted"] + ] + different = fl_genewise_on_refitted != init_pooled_genewise_on_refitted + out_of_bounds = (fl_genewise_on_refitted < min_disp + 1e-8) | ( + fl_genewise_on_refitted > max_disp - 1e-8 + ) + if out_of_bounds.sum() > 0: + print("Genewise dispersions out of bounds") + print(fl_adata.var_names[fl_adata.varm["refitted"]][out_of_bounds]) + print(fl_genewise_on_refitted[out_of_bounds]) + print(init_pooled_genewise_on_refitted[out_of_bounds]) + assert np.all(different | out_of_bounds) + + # Check that the MAP dispersions of refitted genes have changed + # Except if out of bounds + fl_MAP_on_refitted = fl_adata.varm["MAP_dispersions"][fl_adata.varm["refitted"]] + init_pooled_MAP_on_refitted = init_pooled_dds.varm["MAP_dispersions"][ + pooled_dds.varm["refitted"] + ] + different = fl_MAP_on_refitted != init_pooled_MAP_on_refitted + out_of_bounds = (fl_MAP_on_refitted < min_disp + 1e-8) | ( + fl_MAP_on_refitted > max_disp - 1e-8 + ) + if out_of_bounds.sum() > 0: + print("MAP dispersions out of bounds") + 
print(fl_adata.var_names[fl_adata.varm["refitted"]][out_of_bounds]) + print(fl_MAP_on_refitted[out_of_bounds]) + print(init_pooled_MAP_on_refitted[out_of_bounds]) + assert np.all(different | out_of_bounds) + + # Check that the LFC of refitted genes have changed + # Except if out of bounds + fl_LFC_on_refitted = fl_adata.varm["LFC"][fl_adata.varm["refitted"]].to_numpy() + init_pooled_LFC_on_refitted = init_pooled_dds.varm["LFC"][ + pooled_dds.varm["refitted"] + ].to_numpy() + different = fl_LFC_on_refitted != init_pooled_LFC_on_refitted + out_of_bounds = (fl_LFC_on_refitted < -max_beta + 1e-8) | ( + fl_LFC_on_refitted > max_beta - 1e-8 + ) + if out_of_bounds.sum() > 0: + print("LFC out of bounds") + print(fl_adata.var_names[fl_adata.varm["refitted"]][out_of_bounds]) + print(fl_LFC_on_refitted[out_of_bounds]) + print(init_pooled_LFC_on_refitted[out_of_bounds]) + assert np.all(different | out_of_bounds) + + if (~pooled_dds.varm["refitted"]).sum() > 0: + # Check that the genewise dispersions of non-refitted genes have not changed + np.testing.assert_array_almost_equal( + fl_adata.varm["genewise_dispersions"][~fl_adata.varm["refitted"]], + pooled_dds.varm["genewise_dispersions"][~pooled_dds.varm["refitted"]], + decimal=6, + ) + + # Check that the MAP dispersions of non-refitted genes have not changed + np.testing.assert_array_almost_equal( + fl_adata.varm["MAP_dispersions"][~fl_adata.varm["refitted"]], + pooled_dds.varm["MAP_dispersions"][~pooled_dds.varm["refitted"]], + decimal=6, + ) + + # Check that the LFC of non-refitted genes have not changed + np.testing.assert_array_almost_equal( + fl_adata.varm["LFC"][~fl_adata.varm["refitted"]].values, + pooled_dds.varm["LFC"][~pooled_dds.varm["refitted"]].values, + decimal=6, + ) + + +class RefitOutliersTester( + UnitTester, + ReplaceCooksOutliers, + DESeq2LFCDispersions, + GetNumReplicates, + ReplaceRefittedValues, +): + """A class to implement a unit test for outlier imputation. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + min_replicates : int + The minimum number of replicates for a gene to be considered for outlier + replacement. (default: 7). + + trimmed_mean_num_iter: int + The number of iterations to use when computing the trimmed mean + in a federated way, i.e. the number of dichotomy steps. 
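+
+    Notes
+    -----
+    The compute plan built by this tester chains, in order: the construction of
+    the local reference DeseqDataSet objects, the computation of the total
+    number of samples and of the global Gram matrix, the computation of the
+    number of replicates, the replacement of Cooks outliers, a re-run of the
+    LFC and dispersions pipeline in refit mode, and the replacement of the
+    refitted values in the main ``local_adata``. The local and refit AnnData
+    objects are then concatenated and registered so that the test above can
+    compare them with the pooled reference.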
+ """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + min_replicates: int = 7, + min_disp: float = 1e-8, + max_disp: float = 10.0, + grid_batch_size: int = 250, + grid_depth: int = 3, + grid_length: int = 100, + num_jobs=8, + min_mu: float = 0.5, + beta_tol: float = 1e-8, + max_beta: float = 30, + irls_num_iter: int = 20, + joblib_backend: str = "loky", + joblib_verbosity: int = 0, + irls_batch_size: int = 100, + independent_filter: bool = True, + alpha: float = 0.05, + PQN_c1: float = 1e-4, + PQN_ftol: float = 1e-7, + PQN_num_iters_ls: int = 20, + PQN_num_iters: int = 100, + PQN_min_mu: float = 0.0, + trimmed_mean_num_iter: int = 40, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + min_replicates=min_replicates, + min_disp=min_disp, + max_disp=max_disp, + grid_batch_size=grid_batch_size, + grid_depth=grid_depth, + grid_length=grid_length, + num_jobs=num_jobs, + min_mu=min_mu, + beta_tol=beta_tol, + max_beta=max_beta, + irls_num_iter=irls_num_iter, + joblib_backend=joblib_backend, + joblib_verbosity=joblib_verbosity, + irls_batch_size=irls_batch_size, + independent_filter=independent_filter, + alpha=alpha, + PQN_c1=PQN_c1, + PQN_ftol=PQN_ftol, + PQN_num_iters_ls=PQN_num_iters_ls, + PQN_num_iters=PQN_num_iters, + PQN_min_mu=PQN_min_mu, + trimmed_mean_num_iter=trimmed_mean_num_iter, + ) + + #### Define hyper parameters #### + + self.min_disp = min_disp + self.max_disp = max_disp + self.grid_batch_size = grid_batch_size + self.grid_depth = grid_depth + self.grid_length = grid_length + self.min_mu = min_mu + self.beta_tol = beta_tol + self.max_beta = max_beta + + # Parameters of the IRLS algorithm + self.irls_num_iter = irls_num_iter + self.min_replicates = min_replicates + self.PQN_c1 = PQN_c1 + self.PQN_ftol = PQN_ftol + self.PQN_num_iters_ls = PQN_num_iters_ls + self.PQN_num_iters = PQN_num_iters + self.PQN_min_mu = PQN_min_mu + + # Parameters for the trimmed mean + self.trimmed_mean_num_iter = trimmed_mean_num_iter + + #### Define job parallelization parameters #### + self.num_jobs = num_jobs + self.joblib_verbosity = joblib_verbosity + self.joblib_backend = joblib_backend + self.irls_batch_size = irls_batch_size + + # Save on disk + self.layers_to_save_on_disk = { + "local_adata": ["cooks"], + "refit_adata": ["cooks"], + } + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a FedDESeq2 pipe. + + Parameters + ---------- + train_data_nodes : List[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). 
+ """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, init_shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + ### Aggregation step to compute the total number of samples ### + shared_state, round_idx = aggregation_step( + aggregation_method=self.sum_num_samples_and_gram_matrix, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=init_shared_states, + description="Get the total number of samples.", + round_idx=round_idx, + clean_models=clean_models, + ) + + local_states, max_disp_shared_states, round_idx = local_step( + local_method=self.set_tot_num_samples_and_gram_matrix, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=shared_state, + aggregation_id=aggregation_node.organization_id, + description="Set the total number of samples locally.", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Compute the number of replicates #### + + local_states, round_idx = self.get_num_replicates( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + local_states, gram_features_shared_state, round_idx = self.replace_outliers( + train_data_nodes, + aggregation_node, + local_states, + cooks_shared_state=None, + round_idx=round_idx, + clean_models=clean_models, + ) + + ##### Run the pipelline in refit mode ##### + + local_states, round_idx = self.run_deseq2_lfc_dispersions( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + gram_features_shared_states=gram_features_shared_state, + round_idx=round_idx, + clean_models=clean_models, + refit_mode=True, + ) + + # Replace values in the main ``local_adata`` object + local_states, round_idx = self.replace_refitted_values( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + # Concatenate local_adatas and store results + # 1 - Local centers return their local adatas in a local state + local_states, shared_states, round_idx = local_step( + local_method=self.get_adatas, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Return local and refit adatas in a shared state", + round_idx=round_idx, + clean_models=clean_models, + ) + + # 2 - Return concatenated adatas + aggregation_step( + aggregation_method=self.return_adatas, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Return the lists of local and refit adatas", + round_idx=round_idx, + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, 
data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Copy the reference dds and add the total number of samples. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Shared state with a "num_samples" key. + """ + + self.local_adata = self.local_reference_dds.copy() + self.local_adata.varm["non_zero"] = self.local_reference_dds.varm["non_zero"] + self.local_adata.uns["n_params"] = self.local_reference_dds.obsm[ + "design_matrix" + ].shape[1] + design_matrix = self.local_adata.obsm["design_matrix"].values + + return { + "num_samples": self.local_adata.n_obs, + "local_gram_matrix": design_matrix.T @ design_matrix, + } + + @remote + @log_remote + def sum_num_samples_and_gram_matrix(self, shared_states): + """Compute the total number of samples to set max_disp and gram matrix. + + Parameters + ---------- + shared_states : List + List of initial shared states copied from the reference adata. + + Returns + ------- + dict + + """ + tot_num_samples = np.sum([state["num_samples"] for state in shared_states]) + # Sum the local gram matrices + tot_gram_matrix = sum([state["local_gram_matrix"] for state in shared_states]) + return { + "tot_num_samples": tot_num_samples, + "global_gram_matrix": tot_gram_matrix, + } + + @remote_data + @log_remote_data + @reconstruct_adatas + def set_tot_num_samples_and_gram_matrix( + self, data_from_opener: ad.AnnData, shared_state: Any + ): + """Set the total number of samples in the study, and the gram matrix. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Shared state with "fitted_dispersions" and "prior_disp_var" keys. + """ + + self.local_adata.uns["tot_num_samples"] = shared_state["tot_num_samples"] + self.local_adata.uns["max_disp"] = max( + self.max_disp, shared_state["tot_num_samples"] + ) + # TODO this is not used but the key is expected + self.local_adata.uns["mean_disp"] = None + self.local_adata.uns["_global_gram_matrix"] = shared_state["global_gram_matrix"] + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_adatas( + self, + data_from_opener, + shared_state: dict, + ) -> dict: + """Return adatas. + + Used for testing only.""" + + return { + "local_adata": self.local_adata, + "refit_adata": self.refit_adata, + "min_disp": self.min_disp, + "max_beta": self.max_beta, + } + + @remote + @log_remote + def return_adatas( + self, + shared_states: list, + ): + """Return the adatas as lists. 
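+
+        The concatenated lists, together with ``min_disp`` and ``max_beta``,
+        are stored in ``self.results`` by this aggregation step.
+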
Used for testing only.""" + + local_adatas = [shared_state["local_adata"] for shared_state in shared_states] + + refit_adatas = [shared_state["refit_adata"] for shared_state in shared_states] + + self.results = { + "local_adatas": local_adatas, + "refit_adatas": refit_adatas, + "min_disp": shared_states[0]["min_disp"], + "max_beta": shared_states[0]["max_beta"], + } diff --git a/tests/unit_tests/deseq2_core/test_replace_cooks.py b/tests/unit_tests/deseq2_core/test_replace_cooks.py new file mode 100644 index 0000000..5b90301 --- /dev/null +++ b/tests/unit_tests/deseq2_core/test_replace_cooks.py @@ -0,0 +1,576 @@ +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +import numpy as np +import pandas as pd +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.compute_cook_distance import ComputeCookDistances +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_genewise_dispersions.get_num_replicates import ( # noqa: E501 + GetNumReplicates, +) +from fedpydeseq2.core.deseq2_core.replace_outliers import ReplaceCooksOutliers +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.levels import make_reference_and_fl_ref_levels +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_replace_cooks_on_small_genes( + design_factors, + continuous_factors, + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + replace_cooks_testing_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_replace_cooks_on_small_samples_on_self_hosted_fast( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + replace_cooks_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + 
only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +@pytest.mark.parametrize( + "design_factors, continuous_factors", + [ + ("stage", None), + (["stage", "gender"], None), + (["stage", "gender", "CPE"], ["CPE"]), + ], +) +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_replace_cooks_on_self_hosted_slow( + design_factors, + continuous_factors, + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + replace_cooks_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + ) + + +def replace_cooks_testing_pipe( + raw_data_path, + processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, +): + """Perform a unit test for outlier imputation. + + Starting with the same counts as the reference dataset, replace count values of + Cooks outliers and compare with the reference. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + continuous_factors: list or None + The continuous factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + """ + complete_ref_levels, reference_dds_ref_level = make_reference_and_fl_ref_levels( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # Setup the ground truth path. 
+ experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=continuous_factors, + ) + + reference_data_path = processed_data_path / "centers_data" / "tcga" / experiment_id + + fl_results = run_tcga_testing_pipe( + ReplaceOutliersTester( + design_factors=design_factors, + continuous_factors=continuous_factors, + ref_levels=ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=raw_data_path, + processed_data_path=processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + continuous_factors=continuous_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + pooled_dds._replace_outliers() + + if hasattr(pooled_dds, "counts_to_refit"): + pooled_imputed_counts = pd.DataFrame( + pooled_dds.counts_to_refit.X, + index=pooled_dds.counts_to_refit.obs_names, + columns=pooled_dds.counts_to_refit.var_names, + ) + else: + pooled_imputed_counts = pd.DataFrame(index=pooled_dds.obs_names, columns=[]) + + # FL counts should be the restriction of the pooled counts to the + # refitted genes and samples (the fl results are restricted to + # genes that must be refitted and not only the replaced genes) + fl_imputed_counts = fl_results["imputed_counts"] + pooled_imputed_counts = pooled_imputed_counts.loc[ + fl_imputed_counts.index, fl_imputed_counts.columns + ] + + pd.testing.assert_frame_equal(pooled_imputed_counts, fl_imputed_counts) + + +class ReplaceOutliersTester( + UnitTester, + ReplaceCooksOutliers, + ComputeCookDistances, + GetNumReplicates, +): + """A class to implement a unit test for outlier imputation. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + min_replicates : int + The minimum number of replicates for a gene to be considered for outlier + replacement. 
(default: 7). + + min_mu : float + The minimum value of mu used for Cooks distance computation. + + trimmed_mean_num_iter: int + The number of iterations to use when computing the trimmed mean + in a federated way, i.e. the number of dichotomy steps. + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + min_replicates: int = 7, + min_mu: float = 0.5, + trimmed_mean_num_iter: int = 40, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + min_replicates=min_replicates, + trimmed_mean_num_iter=trimmed_mean_num_iter, + ) + + self.min_replicates = min_replicates + self.min_mu = min_mu + self.trimmed_mean_num_iter = trimmed_mean_num_iter + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a FedDESeq2 pipe. + + Parameters + ---------- + train_data_nodes : List[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, init_shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + ### Aggregation step to compute the total number of samples ### + shared_state, round_idx = aggregation_step( + aggregation_method=self.sum_num_samples, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=init_shared_states, + description="Get the total number of samples.", + round_idx=round_idx, + clean_models=clean_models, + ) + + local_states, max_disp_shared_states, round_idx = local_step( + local_method=self.set_tot_num_samples, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=shared_state, + aggregation_id=aggregation_node.organization_id, + description="Set the total number of samples locally.", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Compute the number of replicates #### + + local_states, round_idx = self.get_num_replicates( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + # Compute Cooks distances and impute outliers + local_states, shared_state, 
round_idx = self.compute_cook_distance( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ) + + local_states, refit_features_shared_states, round_idx = self.replace_outliers( + train_data_nodes, + aggregation_node, + local_states, + cooks_shared_state=shared_state, + round_idx=round_idx, + clean_models=clean_models, + ) + + ##### Save results ##### + + local_states, shared_states, round_idx = local_step( + local_method=self.loc_return_imputed_counts, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Return imputed counts in a shared state", + round_idx=round_idx, + clean_models=clean_models, + ) + + aggregation_step( + aggregation_method=self.agg_merge_imputed_counts, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Merge the lists of local imputed counts", + round_idx=round_idx, + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def loc_return_imputed_counts( + self, + data_from_opener, + shared_state: dict, + ) -> dict: + """Return the imputed counts as a DataFrame in a shared state. + + Used for testing only.""" + + return { + "imputed_counts": pd.DataFrame( + self.refit_adata.X, + index=self.refit_adata.obs_names, + columns=self.refit_adata.var_names, + ) + } + + @remote + @log_remote + def agg_merge_imputed_counts( + self, + shared_states: dict, + ): + """Merge the imputed counts. Used for testing only.""" + + imputed_counts = pd.concat( + [shared_state["imputed_counts"] for shared_state in shared_states], + axis=0, + ) + self.results = {"imputed_counts": imputed_counts} + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Copy the reference dds and add the total number of samples. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Shared state with a "num_samples" key. + """ + + self.local_adata = self.local_reference_dds.copy() + # Delete the unused "_mu_hat" layer as it will raise errors + del self.local_adata.layers["_mu_hat"] + self.local_adata.varm["non_zero"] = self.local_reference_dds.varm["non_zero"] + self.local_adata.uns["n_params"] = self.local_reference_dds.obsm[ + "design_matrix" + ].shape[1] + + return {"num_samples": self.local_adata.n_obs} + + @remote + @log_remote + def sum_num_samples(self, shared_states): + """Compute the total number of samples to set max_disp. + + Parameters + ---------- + shared_states : List + List of initial shared states copied from the reference adata. + + Returns + ------- + dict + + """ + tot_num_samples = np.sum([state["num_samples"] for state in shared_states]) + return {"tot_num_samples": tot_num_samples} + + @remote_data + @log_remote_data + @reconstruct_adatas + def set_tot_num_samples(self, data_from_opener: ad.AnnData, shared_state: Any): + """Set the total number of samples in the study. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Shared state with "fitted_dispersions" and "prior_disp_var" keys. 
+ """ + + self.local_adata.uns["tot_num_samples"] = shared_state["tot_num_samples"] diff --git a/tests/unit_tests/deseq2_core/test_save_pipeline_results.py b/tests/unit_tests/deseq2_core/test_save_pipeline_results.py new file mode 100644 index 0000000..7d8e413 --- /dev/null +++ b/tests/unit_tests/deseq2_core/test_save_pipeline_results.py @@ -0,0 +1,429 @@ +"""Module to test the correct saving of the results. + +This module tests the fact that we recover the desired results both +in the simulation mode and in the subprocess mode, as they have +quite different behaviors. +""" + +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +import numpy as np +import pytest +from substrafl import ComputePlanBuilder +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.save_pipeline_results import SavePipelineResults +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.logging import log_save_local_state +from tests.tcga_testing_pipe import run_tcga_testing_pipe + + +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_save_pipeline_results_on_small_genes( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + save_pipeline_results_testing_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + ) + + +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_save_pipeline_results_on_subprocess( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, +): + save_pipeline_results_testing_pipe( + raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=False, + backend="subprocess", + only_two_centers=False, + ) + + +@pytest.mark.self_hosted_fast +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_save_pipeline_results_on_small_samples_on_self_hosted_fast( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + save_pipeline_results_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_save_pipeline_results_on_self_hosted_slow( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + save_pipeline_results_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + 
small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_save_pipeline_results_on_subprocess_on_self_hosted_slow( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + save_pipeline_results_testing_pipe( + raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=False, + backend="subprocess", + only_two_centers=False, + ) + + +def save_pipeline_results_testing_pipe( + raw_data_path, + processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, +): + """Test the SavePipelineResults class. + + + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + """ + fl_results = run_tcga_testing_pipe( + SavePipelineResultsTester(), + raw_data_path=raw_data_path, + processed_data_path=processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + reference_dds_ref_level=None, + ) + + # Check that all fields are present + assert isinstance(fl_results, dict) + for key in SavePipelineResults.VARM_KEYS: + assert key in fl_results + for key in SavePipelineResults.UNS_KEYS: + assert key in fl_results + assert "gene_names" in fl_results + + +class SavePipelineResultsTester(ComputePlanBuilder, SavePipelineResults): + """Tester for the SavePipelineResults class. + + Parameters + ---------- + n_rounds : int + Number of rounds. + + """ + + def __init__( + self, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__() + + self.local_adata: ad.AnnData | None = None + self.results: dict | None = None + self.refit_adata: ad.AnnData | None = None + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a DESeq2 pipe. + + Parameters + ---------- + train_data_nodes : List[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. 
+ clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Load the opener data local_adata and set local states #### + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=None, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + shared_state, round_idx = aggregation_step( + aggregation_method=self.create_all_fields_with_dummies, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Create all fields with dummies", + round_idx=round_idx, + clean_models=clean_models, + ) + + local_states, shared_states, round_idx = local_step( + local_method=self.set_dummies_in_local_adata, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=shared_state, + aggregation_id=aggregation_node.organization_id, + description="Set dummies in local adata", + round_idx=round_idx, + clean_models=clean_models, + ) + + self.save_pipeline_results( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Copy the reference dds and add the total number of samples. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Shared state with a "num_samples" key. + """ + + self.local_adata = data_from_opener.copy() + return {"num_vars": self.local_adata.n_vars} + + @remote + @log_remote + def create_all_fields_with_dummies(self, shared_states: dict): + """Set all fields with dummies. Used for testing only.""" + num_vars = shared_states[0]["num_vars"] + varm_dummies = {} + + if "replaced" in self.VARM_KEYS: + varm_dummies["replaced"] = np.random.rand(num_vars) > 0.5 + if "refitted" in self.VARM_KEYS: + assert "replaced" in self.VARM_KEYS + varm_dummies["replaced"] = np.random.rand(num_vars) > 0.5 + varm_dummies["refitted"] = varm_dummies["replaced"] & ( + np.random.rand(num_vars) > 0.5 + ) + + for varm_key in self.VARM_KEYS: + if varm_key in {"refitted", "replaced"}: + # Create a random boolean array of size num_vars + continue + varm_dummies[varm_key] = np.random.rand(num_vars, 2) + + uns_dummmies = {} + for uns_key in self.UNS_KEYS: + uns_dummmies[uns_key] = np.random.rand(2) + + return {"varm_dummies": varm_dummies, "uns_dummmies": uns_dummmies} + + @remote_data + @log_remote_data + @reconstruct_adatas + def set_dummies_in_local_adata( + self, + data_from_opener, + shared_state: dict, + ) -> dict: + """Create a dummy array. + + Parameters + ---------- + data_from_opener : Any + Not used. + + shared_state : dict + Shared state with "varm_dummies" and "uns_dummmies" keys. 
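+
+        Returns
+        -------
+        dict
+            A dictionary with a "num_vars" key containing the number of
+            variables of the local AnnData.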
+ + """ + for varm_key, varm_dummy in shared_state["varm_dummies"].items(): + self.local_adata.varm[varm_key] = varm_dummy + + for uns_key, uns_dummy in shared_state["uns_dummmies"].items(): + self.local_adata.uns[uns_key] = uns_dummy + + print("Self refit adata") + print(self.refit_adata) + + return {"num_vars": self.local_adata.n_vars} + + @log_save_local_state + def save_local_state(self, path: Path) -> None: + """Save the local state of the strategy. + + Parameters + ---------- + path : Path + Path to the file where to save the state. Automatically handled by subtrafl. + """ + state_to_save = { + "results": self.results, + "local_adata": self.local_adata, + "refit_adata": self.refit_adata, + } + with open(path, "wb") as file: + pkl.dump(state_to_save, file) + + def load_local_state(self, path: Path) -> Any: + """Load the local state of the strategy. + + Parameters + ---------- + path : Path + Path to the file where to load the state from. Automatically handled by + subtrafl. + """ + with open(path, "rb") as file: + state_to_load = pkl.load(file) + + self.results = state_to_load["results"] + self.local_adata = state_to_load["local_adata"] + self.refit_adata = state_to_load["refit_adata"] + + return self + + @property + def num_round(self): + """Return the number of round in the strategy. + + TODO do something clever with this. + + Returns + ------- + int + Number of round in the strategy. + """ + return None diff --git a/tests/unit_tests/deseq2_core/test_size_factors.py b/tests/unit_tests/deseq2_core/test_size_factors.py new file mode 100644 index 0000000..f0e8093 --- /dev/null +++ b/tests/unit_tests/deseq2_core/test_size_factors.py @@ -0,0 +1,466 @@ +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +import numpy as np +import pytest +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.deseq2_core.compute_size_factors import ComputeSizeFactors +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +def test_size_factors( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=True, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors: str | list[str] = "stage", + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, +): + """Perform a unit test for the size factors. + + Starting with the same counts as the reference dataset, compute size factors and + compare the results with the reference. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. 
The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: str + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + """ + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=None, + ) + + reference_data_path = ( + local_processed_data_path / "centers_data" / "tcga" / experiment_id + ) + # Get FL results. + fl_results = run_tcga_testing_pipe( + SizeFactorsTester( + design_factors=design_factors, + ref_levels=ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=raw_data_path, + processed_data_path=local_processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + local_processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # TODO size factors are sorted because we don't have access to indexes + # There could be issues in case of non-unique values + assert np.allclose( + np.sort(fl_results["size_factors"]), + np.sort(pooled_dds.obsm["size_factors"]), + equal_nan=True, + ) + + +class SizeFactorsTester( + UnitTester, + ComputeSizeFactors, +): + """A class to implement a unit test for the size factors. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. 
This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + contrast : list or None + A list of three strings, in the following format: + ``['variable_of_interest', 'tested_level', 'ref_level']``. + Names must correspond to the metadata data passed to the DeseqDataSet. + E.g., ``['condition', 'B', 'A']`` will measure the LFC of 'condition B' compared + to 'condition A'. + For continuous variables, the last two strings should be left empty, e.g. + ``['measurement', '', ''].`` + If None, the last variable from the design matrix is chosen + as the variable of interest, and the reference level is picked alphabetically. + (default: ``None``). + + Methods + ------- + init_local_states + Local method to initialize the local states. + It returns the design columns of the local design matrix, which is required to + start the computation. + + merge_design_columns_and_build_contrast + Aggregation method to merge the columns of the design matrices and build the + contrast. + This method returns a shared state containing + - merged_columns: the names of the columns that the local design matrices should + have. + - contrast: the contrast (in a list of strings form) to be used for the DESeq2 + These are required to start the first local step of the computation + of size factors. + + compute_local_size_factors + Local method to compute the size factors. + Indeed, the compute_size_factors method only returns an aggregated state + which allows to compute the local size factor and not the size factor itself. + In the main pipeline, the explicit computation of the local size factors + is done in the first local step of compute_mom_dispersions. + + + concatenate_size_factors + Aggregation method to concatenate the size factors. + This method returns the concatenated (pooled equivalent) size factors + + save_size_factors + Local method to save the size factors, combining the two previous methods + as well as the build_share_results_tasks method. + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + contrast: list[str] | None = None, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + contrast=contrast, + ) + + self.contrast = contrast + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to test computing the size factors. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). 
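+
+ Notes
+ -----
+ The plan chains four stages: setting the local reference DeseqDataSet,
+ initializing the local states (which share the local log means),
+ running ``compute_size_factors``, and finally concatenating and saving
+ the size factors through ``save_size_factors``.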
+ """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Compute size factors #### + ( + local_states, + shared_states, + round_idx, + ) = self.compute_size_factors( + train_data_nodes, + aggregation_node, + local_states, + shared_states=shared_states, + round_idx=round_idx, + clean_models=clean_models, + ) + + ### Save results ### + self.save_size_factors( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=False, + ) + + def save_size_factors( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """Check size factors. + + Parameters + ---------- + train_data_nodes: list + List of TrainDataNode. + + aggregation_node: AggregationNode + The aggregation node. + + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: int + The current round index. + + clean_models: bool + Whether to clean the models after the computation. + + Returns + ------- + local_states: dict + Local states. Required to propagate intermediate results. + + round_idx: int + The updated round index. + """ + + # ---- Compute and share local size factors ----# + + local_states, shared_states, round_idx = local_step( + local_method=self.get_local_size_factors, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + round_idx=round_idx, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Get the local size factors from the adata", + clean_models=clean_models, + ) + + # ---- Concatenate local size factors ----# + + aggregation_step( + aggregation_method=self.concatenate_size_factors, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Concatenating local size factors", + clean_models=clean_models, + ) + + @remote + @log_remote + def concatenate_size_factors(self, shared_states): + """Concatenate size factors together for registration. + + Use for testing purposes only. + + Parameters + ---------- + shared_states : list + List of results (size_factors) from training nodes. + + Returns + ------- + dict + Concatenated (pooled) size factors. + """ + tot_sf = np.hstack( + [state["size_factors"] for state in shared_states], + ) + self.results = {"size_factors": tot_sf} + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Copy the reference dds to the local state and compute local log mean. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener, used to compute the local log mean. + + shared_state : None, optional + Not used. + + Returns + ------- + dict + Local mean of logs and number of samples. 
+ """ + + self.local_adata = self.local_reference_dds.copy() + # Delete the unused "_mu_hat" layer as it will raise errors + del self.local_adata.layers["_mu_hat"] + # This field is not saved in pydeseq2 but used in fedpyseq2 + self.local_adata.uns["n_params"] = self.local_adata.obsm["design_matrix"].shape[ + 1 + ] + + with np.errstate(divide="ignore"): # ignore division by zero warnings + return { + "log_mean": np.log(data_from_opener.X).mean(axis=0), + "n_samples": data_from_opener.n_obs, + } + + @remote_data + @log_remote_data + @reconstruct_adatas + def get_local_size_factors( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Get the local size factors. + + Parameters + ---------- + data_from_opener : ad.AnnData + Not used. + + shared_state : None + Not used. + + Returns + ------- + dict + A dictionary containing the size factors in the "size_factors" key. + """ + + return {"size_factors": self.local_adata.obsm["size_factors"]} diff --git a/tests/unit_tests/fed_algorithms/__init__.py b/tests/unit_tests/fed_algorithms/__init__.py new file mode 100644 index 0000000..db80c19 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/__init__.py @@ -0,0 +1 @@ +"""Tests for the fed_algorithms module.""" diff --git a/tests/unit_tests/fed_algorithms/fed_IRLS/__init__.py b/tests/unit_tests/fed_algorithms/fed_IRLS/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/fed_algorithms/fed_IRLS/fed_IRLS_tester.py b/tests/unit_tests/fed_algorithms/fed_IRLS/fed_IRLS_tester.py new file mode 100644 index 0000000..7df6700 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/fed_IRLS/fed_IRLS_tester.py @@ -0,0 +1,370 @@ +"""A class to implement a unit tester class for FedIRLS.""" +from pathlib import Path + +import anndata as ad +import numpy as np +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.fed_algorithms import FedIRLS +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.pass_on_results import AggPassOnResults +from tests.unit_tests.fed_algorithms.fed_IRLS_PQN_tester import FedIRLSPQNTester +from tests.unit_tests.fed_algorithms.fed_IRLS_PQN_tester import compute_initial_beta +from tests.unit_tests.unit_test_helpers.pass_on_first_shared_state import ( + AggPassOnFirstSharedState, +) + + +class FedIRLSTester( + FedIRLSPQNTester, FedIRLS, AggPassOnResults, AggPassOnFirstSharedState +): + """A class to implement a unit test for FedIRLS. + + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. 
Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + min_mu : float + Lower threshold for mean expression values. (default: ``0.5``). + + beta_tol : float + Tolerance for the beta coefficients. (default: ``1e-8``). + + max_beta : float + Upper threshold for the beta coefficients. (default: ``30``). + + irls_num_iter : int + Number of iterations for the IRLS algorithm. (default: ``20``). + + joblib_backend : str + The backend to use for the IRLS algorithm. (default: ``"loky"``). + + num_jobs : int + Number of CPUs to use for parallelization. (default: ``8``). + + joblib_verbosity : int + Verbosity level for joblib. (default: ``3``). + + irls_batch_size : int + Batch size for the IRLS algorithm. (default: ``100``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + Methods + ------- + init_local_states + A remote_data method which sets the local adata and the local gram matrix + from the reference_dds. + + compute_gram_matrix + A remote method which computes the gram matrix. + + set_gram_matrix + A remote_data method which sets the gram matrix in the local adata. + + build_compute_plan + Build the computation graph to run a FedIRLS algorithm. + + save_irls_results + The method to save the IRLS results. + + + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + min_mu: float = 0.5, + beta_tol: float = 1e-8, + max_beta: float = 30, + irls_num_iter: int = 20, + joblib_backend: str = "loky", + num_jobs: int = 8, + joblib_verbosity: int = 3, + irls_batch_size: int = 100, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + min_mu=min_mu, + beta_tol=beta_tol, + max_beta=max_beta, + irls_num_iter=irls_num_iter, + joblib_backend=joblib_backend, + num_jobs=num_jobs, + joblib_verbosity=joblib_verbosity, + irls_batch_size=irls_batch_size, + ) + + #### Define hyper parameters #### + + self.min_mu = min_mu + self.beta_tol = beta_tol + self.max_beta = max_beta + + # Parameters of the IRLS algorithm + self.irls_num_iter = irls_num_iter + + #### Define job parallelization parameters #### + self.joblib_verbosity = joblib_verbosity + self.num_jobs = num_jobs + self.joblib_backend = joblib_backend + self.irls_batch_size = irls_batch_size + + self.layers_to_save_on_disk = { + "local_adata": [ + "_mu_hat", + ], + "refit_adata": [ + None, + ], + } + + @remote + @log_remote + def compute_start_state(self, shared_states: list[dict]) -> dict: + """Compute the beta initialization, and share to the centers + + Parameters + ---------- + shared_states : list + List of shared states. + + Returns + ------- + dict + Dictionary containing the starting state of the IRLS algorithm. 
+ It contains the following keys: + - beta: ndarray + The current beta, of shape (n_non_zero_genes, n_params). + - irls_diverged_mask: ndarray + A boolean mask indicating if fed avg should be used for a given gene + (shape: (n_non_zero_genes,)). It is initialized to False. + - irls_mask: ndarray + A boolean mask indicating if IRLS should be used for a given gene + (shape: (n_non_zero_genes,)). It is initialized to True. + - global_nll: ndarray + The global_nll of the current beta from the previous beta, of shape + (n_non_zero_genes,). It is initialized to 1000.0. + - round_number_irls: int + The current round number of the IRLS algorithm. It is initialized to 0. + + """ + + beta_init = compute_initial_beta(shared_states) + n_non_zero_genes = beta_init.shape[0] + + # Share it with the centers + return { + "beta": beta_init, + "irls_mask": np.full(n_non_zero_genes, True), + "irls_diverged_mask": np.full(n_non_zero_genes, False), + "global_nll": np.full(n_non_zero_genes, 1000.0), + "round_number_irls": 0, + } + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a FedDESeq2 pipe. + + Parameters + ---------- + train_data_nodes : List[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + # This step also shares the counts and design matrix to compute + # the beta initialization directly + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Compute the initialization shared state #### + + starting_shared_state, round_idx = aggregation_step( + aggregation_method=self.compute_start_state, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Compute the global gram matrix.", + clean_models=clean_models, + ) + + #### Set the beta init in the local states #### + + local_states, shared_states, round_idx = local_step( + local_method=self.set_beta_init, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=starting_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Set the beta init", + round_idx=round_idx, + clean_models=clean_models, + ) + + starting_shared_state, round_idx = aggregation_step( + aggregation_method=self.pass_on_shared_state, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Pass on 
the shared state", + clean_models=clean_models, + ) + + #### Perform fed PQN #### + + local_states, irls_shared_state, round_idx = self.run_fed_irls( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + input_shared_state=starting_shared_state, + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Share the results #### + + # Local step to add non_zero genes to the shared state + local_states, shared_states, round_idx = local_step( + local_method=self.local_add_non_zero_genes, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=irls_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Add non zero genes to the shared state", + round_idx=round_idx, + clean_models=clean_models, + ) + + # Save the results + aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Save the first shared state", + clean_models=False, + ) + + @remote_data + @log_remote_data + @reconstruct_adatas + def local_add_non_zero_genes( + self, data_from_opener: ad.AnnData, shared_state: dict + ) -> dict: + """Initialize the local states. + + This methods sets the local_adata and the local gram matrix. + from the reference_dds. + + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Shared state which comes from the last iteration of + the Fed Prox Quasi Newton algorithm. + Contains a `beta` and a `PQN_diverged_mask` field. + + Returns + ------- + local_state : dict + The local state containing the non zero genes mask and genes + as an addition to the input shared state. + + """ + non_zero_genes_names = self.local_adata.var_names[ + self.local_adata.varm["non_zero"] + ] + non_zero_genes_mask = self.local_adata.varm["non_zero"] + return { + "irls_diverged_mask": shared_state["irls_diverged_mask"] + | shared_state["irls_mask"], + "beta": shared_state["beta"], + "non_zero_genes_mask": non_zero_genes_mask, + "non_zero_genes_names": non_zero_genes_names, + } diff --git a/tests/unit_tests/fed_algorithms/fed_IRLS/irls_test_pipe.py b/tests/unit_tests/fed_algorithms/fed_IRLS/irls_test_pipe.py new file mode 100644 index 0000000..aaca5e3 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/fed_IRLS/irls_test_pipe.py @@ -0,0 +1,157 @@ +"""Module to implement a testing pipeline for PQN method. + +It consists in computing the log fold changes using the PQN method +directly, and checking that the nll obtained using this method +is lower or better than the one obtained using the standard pipe. 
+""" + +import pickle as pkl +from pathlib import Path + +import numpy as np +from fedpydeseq2_datasets.constants import TCGADatasetNames +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from substra import BackendType + +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.fed_algorithms.fed_IRLS.fed_IRLS_tester import FedIRLSTester + + +def pipe_test_compute_lfc_with_irls( + data_path: Path, + processed_data_path: Path, + tcga_assets_directory: Path, + dataset_name: TCGADatasetNames = "TCGA-LUAD", + small_samples: bool = False, + small_genes: bool = False, + simulate: bool = True, + backend: BackendType = "subprocess", + only_two_centers: bool = False, + design_factors: str | list[str] = "stage", + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, + rtol: float = 1e-2, + atol: float = 1e-3, +): + r"""Perform a unit test for the log fold change computation with IRLS. + + As IRLS does not always converge, we check that for all genes + that converged, the log fold changes are the same as the ones + obtained with the pooled data. + + Parameters + ---------- + data_path: Path + The path to the root data. + + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: BackendType + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + + rtol: float + The relative tolerance to use for the comparison. + + atol: float + The absolute tolerance to use for the comparison. + + """ + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=None, + ) + + reference_data_path = processed_data_path / "centers_data" / "tcga" / experiment_id + # Get FL results. 
+ fl_results = run_tcga_testing_pipe( + FedIRLSTester( + design_factors=design_factors, + ref_levels=ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + ), + raw_data_path=data_path, + processed_data_path=processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # FL gene name by convergence type + fl_beta = fl_results["beta"] + fl_irls_diverged_mask = fl_results["irls_diverged_mask"] + fl_non_zero_gene_names = fl_results["non_zero_genes_names"] + converged_gene_names = fl_non_zero_gene_names[~fl_irls_diverged_mask] + + fl_LFC_converged = fl_beta[~fl_irls_diverged_mask, :] + + # pooled LFC results + pooled_LFC_converged = ( + pooled_dds.varm["LFC"].loc[converged_gene_names, :].to_numpy() + ) + + #### ---- Check for the IRLS_converged ---- #### + + LFC_error_tol = np.abs(pooled_LFC_converged) * rtol + atol + LFC_abs_error = np.abs(fl_LFC_converged - pooled_LFC_converged) + + assert np.all(LFC_abs_error < LFC_error_tol) diff --git a/tests/unit_tests/fed_algorithms/fed_IRLS/test_IRLS_utils.py b/tests/unit_tests/fed_algorithms/fed_IRLS/test_IRLS_utils.py new file mode 100644 index 0000000..63e8c51 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/fed_IRLS/test_IRLS_utils.py @@ -0,0 +1,156 @@ +"""Module to test the base functions of the IRLS algorithm.""" + +import numpy as np + +from fedpydeseq2.core.deseq2_core.deseq2_lfc_dispersions.compute_lfc.utils import ( + make_irls_nll_batch, +) +from fedpydeseq2.core.fed_algorithms.fed_irls.utils import ( + make_irls_update_summands_and_nll_batch, +) + + +def test_make_irls_update_summands_and_nll_batch(): + """Test the function make_irls_update_summands_and_nll_batch. + + This test checks that the function returns the correct output shapes. + given input shapes of size (3, 2), (3,), (5, 2), (5,), (3, 5), and a scalar. + """ + # Create fake data + design_matrix = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + size_factors = np.array([1.0, 2.0, 3.0]) + beta = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]) + dispersions = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + counts = np.array( + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [6.0, 7.0, 8.0, 9.0, 10.0], + [11.0, 12.0, 13.0, 14.0, 15.0], + ] + ) + min_mu = 0.1 + + # Call the function with the fake data + H, y, nll = make_irls_update_summands_and_nll_batch( + design_matrix, size_factors, beta, dispersions, counts, min_mu + ) + + # Check that the outputs are correct + assert H.shape == (5, 2, 2) + # Check that H is symmetric + assert np.allclose(H, H.transpose(0, 2, 1)) + assert y.shape == (5, 2) + assert nll.shape == (5,) + + +def test_make_irls_update_summands_and_nll_batch_no_warnings(): + """Test the function make_irls_update_summands_and_nll_batch. + + This test checks that the function returns the correct output shapes. + given input shapes of size (3, 2), (3,), (5, 2), (5,), (3, 5), and a scalar. 
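+
+ In this variant, ``beta`` contains very large values (e.g. 1000 and 2000)
+ and warnings are promoted to errors, so the computation must complete
+ without emitting numerical warnings such as overflows.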
+ """ + # Create fake data + design_matrix = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + size_factors = np.array([1.0, 2.0, 3.0]) + beta = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [1000.0, 2000.0]]) + dispersions = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + counts = np.array( + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [6.0, 7.0, 8.0, 9.0, 10.0], + [11.0, 12.0, 13.0, 14.0, 15.0], + ] + ) + min_mu = 0.1 + + # Call the function with the fake data + import warnings + + warnings.filterwarnings("error") + H, y, nll = make_irls_update_summands_and_nll_batch( + design_matrix, size_factors, beta, dispersions, counts, min_mu + ) + + # Check that the outputs are correct + assert H.shape == (5, 2, 2) + # Check that H is symmetric + assert np.allclose(H, H.transpose(0, 2, 1)) + assert y.shape == (5, 2) + assert nll.shape == (5,) + + +def test_make_irls_update_summands_and_nll_batch_single_design(): + """Test the function make_irls_update_summands_and_nll_batch. + + This test checks the border case where the design matrix has only one row, and + there is only one gene. + """ + # Create fake data + design_matrix = np.array([[1.0]]) + size_factors = np.array([1.0]) + beta = np.array([[1.0]]) + dispersions = np.array([1.0]) + counts = np.array([[1.0]]) + min_mu = 0.1 + + # Call the function with the fake data + H, y, nll = make_irls_update_summands_and_nll_batch( + design_matrix, size_factors, beta, dispersions, counts, min_mu + ) + + # Check that the outputs are correct + assert H.shape == (1, 1, 1) + assert y.shape == (1, 1) + assert nll.shape == (1,) + + +def test_make_irls_nll_batch_specific_sizes(): + """Test the function make_irls_nll_batch. + + This test checks that the function returns the correct output shapes + given input shapes of size (5, 2), (3, 2), (3,), (5,), and a scalar. + """ + # Create fake data + beta = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]) + design_matrix = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + size_factors = np.array([1.0, 2.0, 3.0]) + dispersions = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + counts = np.array( + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [6.0, 7.0, 8.0, 9.0, 10.0], + [11.0, 12.0, 13.0, 14.0, 15.0], + ] + ) + min_mu = 0.1 + + # Call the function with the fake data + nll = make_irls_nll_batch( + beta, design_matrix, size_factors, dispersions, counts, min_mu + ) + + # Check that the outputs are correct + assert nll.shape == (5,) + + +def test_make_irls_nll_batch_single_dim(): + """Test the function make_irls_nll_batch. + + This test checks the border case where the design matrix has only one row, and + there is only one gene. 
+ """ + # Create fake data + beta = np.array([[1.0]]) + design_matrix = np.array([[1.0]]) + size_factors = np.array([1.0]) + dispersions = np.array([1.0]) + counts = np.array([[1.0]]) + min_mu = 0.1 + + # Call the function with the fake data + nll = make_irls_nll_batch( + beta, design_matrix, size_factors, dispersions, counts, min_mu + ) + + # Check that the outputs are correct + assert nll.shape == (1,) diff --git a/tests/unit_tests/fed_algorithms/fed_IRLS/test_irls.py b/tests/unit_tests/fed_algorithms/fed_IRLS/test_irls.py new file mode 100644 index 0000000..75ae1e9 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/fed_IRLS/test_irls.py @@ -0,0 +1,173 @@ +"""Unit test for the Fed Prox Quasi Newton algorithm.""" + +import pytest +from fedpydeseq2_datasets.constants import TCGADatasetNames + +from tests.unit_tests.fed_algorithms.fed_IRLS.irls_test_pipe import ( + pipe_test_compute_lfc_with_irls, +) + +TESTING_PARAMETERS_LIST = [ + "TCGA-LUAD", + "TCGA-PAAD", +] + + +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +def test_lfc_with_irls( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, +): + """Perform a unit test to see if compute_lfc is working as expected. + + Note that the catching of IRLS is very simple here, as there are not enough + genes to observe significant differences in the log fold changes. + + The behaviour of the fed prox algorithm is tested on a self hosted runner. + + Moreover, we only test with the fisher scaling mode, as the other modes are + tested in the other tests, and perform less well in our tested datasets. + + We do not clip mu as this seems to yield better results. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + """ + + pipe_test_compute_lfc_with_irls( + data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + rtol=1e-2, + atol=1e-3, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "dataset_name", + TESTING_PARAMETERS_LIST, +) +def test_lfc_with_irls_on_self_hosted( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name: TCGADatasetNames, +): + """Perform a unit test for compute_lfc using the fisher scaling mode. + + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". 
+ + """ + + pipe_test_compute_lfc_with_irls( + data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name=dataset_name, + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + rtol=1e-2, + atol=1e-3, + ) + + +@pytest.mark.local +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "dataset_name", + TESTING_PARAMETERS_LIST, +) +def test_lfc_with_irls_on_local( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + dataset_name: TCGADatasetNames, +): + """Perform a unit test for compute_lfc. + + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + + """ + + pipe_test_compute_lfc_with_irls( + data_path=raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name=dataset_name, + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + rtol=1e-2, + atol=1e-3, + ) diff --git a/tests/unit_tests/fed_algorithms/fed_IRLS_PQN_tester.py b/tests/unit_tests/fed_algorithms/fed_IRLS_PQN_tester.py new file mode 100644 index 0000000..a6303a0 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/fed_IRLS_PQN_tester.py @@ -0,0 +1,150 @@ +"""A mixin class and a function to implement common test utils for IRLS and PQN.""" + +from typing import Any + +import anndata as ad +import numpy as np +from numpy.linalg import solve +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data +from tests.unit_tests.unit_test_helpers.unit_tester import UnitTester + + +class FedIRLSPQNTester(UnitTester): + """A mixin class to implement method for IRLS and PQN testing classes. + + + Methods + ------- + init_local_states + A remote_data method, which initializes the local states by setting the local + adata. It also returns the normalized counts and the design matrix, in order + to create the intial beta (Note that this is only for testing purposes) + + + set_beta_init + A remote_data method, which sets the beta init in the local states, and passes + on the shared state which is used as an initialization state for the + Prox Quasi Newton algorithm. + + + """ + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Initialize the local states. + + This methods sets the local_adata and the local gram matrix. + from the reference_dds. + + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Not used. + + Returns + ------- + local_states : dict + Local states containing the local Gram matrix. 
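+ More precisely, the shared part of the returned state contains the local
+ normed counts restricted to non-zero genes (under
+ "local_normed_counts_non_zero") and the local design matrix (under
+ "local_design_matrix"), from which the aggregation node computes the
+ initial beta.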
+ """ + + # Using super().init_local_states_from_opener(data_from_opener, shared_state) + # does not work, so we have to duplicate the code + self.local_adata = self.local_reference_dds.copy() + # This field is not saved in pydeseq2 but used in fedpyseq2 + self.local_adata.uns["n_params"] = self.local_adata.obsm["design_matrix"].shape[ + 1 + ] + non_zero_genes_names = self.local_adata.var_names[ + self.local_adata.varm["non_zero"] + ] + + # Get the local gram matrix for all genes + design_matrix = self.local_adata.obsm["design_matrix"].values + + # Get the counts on non zero genes + normed_counts_non_zero = self.local_adata[:, non_zero_genes_names].layers[ + "normed_counts" + ] + + self.local_adata.uns["_irls_disp_param_name"] = "dispersions" + + return { + "local_normed_counts_non_zero": normed_counts_non_zero, + "local_design_matrix": design_matrix, + } + + @remote_data + @log_remote_data + @reconstruct_adatas + def set_beta_init(self, data_from_opener: ad.AnnData, shared_state: dict) -> dict: + """Set the beta init. + + Since the fed prox newton algorithm requires the beta init to be + set in the uns, we do so in this step. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + The initial shared state of the fed pqn algorithm. + + Returns + ------- + local_state : dict + The same initial shared state. + + """ + + self.local_adata.uns["_irls_beta_init"] = shared_state["beta"] + return shared_state + + +def compute_initial_beta(shared_states: list[dict]) -> np.ndarray: + """Compute the beta initialization from a list of shared states. + + Parameters + ---------- + shared_states : list + List of shared states. + + Returns + ------- + np.ndarray + The beta initialization. + """ + + # Concatenate the local design matrices + X = np.concatenate([state["local_design_matrix"] for state in shared_states]) + + # Concatenate the normed counts + normed_counts_non_zero = np.concatenate( + [state["local_normed_counts_non_zero"] for state in shared_states] + ) + + # Compute the beta initialization + num_vars = X.shape[1] + n_non_zero_genes = normed_counts_non_zero.shape[1] + + # if full rank, estimate initial betas for IRLS below + if np.linalg.matrix_rank(X) == num_vars: + Q, R = np.linalg.qr(X) + y = np.log(normed_counts_non_zero + 0.1) + beta_init = solve(R[None, :, :], (Q.T @ y).T[:, :]) + else: # Initialise intercept with log base mean + beta_init = np.zeros(n_non_zero_genes, num_vars) + beta_init[:, 0] = np.log(normed_counts_non_zero).mean(axis=0) + + return beta_init diff --git a/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/__init__.py b/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/__init__.py new file mode 100644 index 0000000..41539f5 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/__init__.py @@ -0,0 +1 @@ +"""Module to test the prox quasi newton method.""" diff --git a/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/fed_pqn_tester.py b/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/fed_pqn_tester.py new file mode 100644 index 0000000..0bc7dd9 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/fed_pqn_tester.py @@ -0,0 +1,380 @@ +"""A class to implement a unit tester class for Fed Prox Quasi Newton.""" +from pathlib import Path + +import anndata as ad +import numpy as np +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from 
substrafl.remote import remote +from substrafl.remote import remote_data + +from fedpydeseq2.core.fed_algorithms import FedProxQuasiNewton +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.pass_on_results import AggPassOnResults +from tests.unit_tests.fed_algorithms.fed_IRLS_PQN_tester import FedIRLSPQNTester +from tests.unit_tests.fed_algorithms.fed_IRLS_PQN_tester import compute_initial_beta +from tests.unit_tests.unit_test_helpers.pass_on_first_shared_state import ( + AggPassOnFirstSharedState, +) + + +class FedProxQuasiNewtonTester( + FedIRLSPQNTester, + FedProxQuasiNewton, + AggPassOnResults, + AggPassOnFirstSharedState, +): + """A class to implement a unit test for FedProxQuasiNewton. + + This unit step is performed on an optimization problem + which is solved by ComputeLFC (and therefore by IRLS in a majority + of cases) in the pipeline. Here, we directly use the FedProxQuasiNewton + method as an optimization algorithm. It is therefore quite oriented + towards the particular optimization problem arising in the DESeq2 pipeline. + + Parameters + ---------- + design_factors : Union[str, list[str]] + The design factors. + + ref_levels : Optional[dict[str, str]] + The reference levels. + + continuous_factors : Optional[list[str]] + The continuous factors. + + PQN_min_mu : float + The minimum mu. + + max_beta : float + The maximum beta. + + joblib_backend : str + The IRLS backend. + + num_jobs : int + The number of CPUs. + + joblib_verbosity : int + The joblib verbosity. + + irls_batch_size : int + The IRLS batch size, i.e. the number of genes used per parallelization + batch. + + PQN_c1 : float + The prox quasi newton c_1 constant used in the Armijo line search. + + PQN_ftol : float + The relative tolerance used as a stopping criterion in the Prox Quasi Newton + method. + + PQN_num_iters_ls : int + The number of iterations used in the line search. + + PQN_num_iters : int + The number of iterations used in the Prox Quasi Newton method. + + reference_data_path : Optional[Union[str, Path]] + The path to the reference data. + + reference_dds_ref_level : Optional[tuple[str, ...]] + The reference level of the reference DeseqDataSet. + + + Methods + ------- + init_local_states + A remote_data method, which initializes the local states by setting the local + adata. It also returns the normalized counts and the design matrix, in order + to create the intial beta (Note that this is only for testing purposes) + + compute_start_state + A remote method, which computes the start state by concatenating the local + design matrices and the normed counts, and computing the initial beta value + as in PyDESeq2 (Note that this is only for testing purposes, and is done in + a federated way in the real pipeline). + + set_beta_init + A remote_data method, which sets the beta init in the local states, and passes + on the shared state which is used as an initialization state for the + Prox Quasi Newton algorithm. + + pass_on_shared_state + A remote method, which passes on the shared state to the centers. + + local_add_non_zero_genes + A remote_data method, which adds the non zero genes (stored in the local adata) + to the shared state, and returns this new shared state to the server. 
+ + build_compute_plan + A method to build the computation graph to run the Fed Prox Quasi Newton + algorithm on the LFC computation problem in PyDESeq2. + + + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + PQN_min_mu: float = 0.5, + max_beta: float = 30, + joblib_backend: str = "loky", + num_jobs: int = 8, + joblib_verbosity: int = 3, + irls_batch_size: int = 100, + PQN_c1: float = 1e-4, + PQN_ftol: float = 1e-7, + PQN_num_iters_ls: int = 20, + PQN_num_iters: int = 100, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + ): + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + PQN_min_mu=PQN_min_mu, + max_beta=max_beta, + joblib_backend=joblib_backend, + num_jobs=num_jobs, + joblib_verbosity=joblib_verbosity, + irls_batch_size=irls_batch_size, + PQN_c1=PQN_c1, + PQN_ftol=PQN_ftol, + PQN_num_iters_ls=PQN_num_iters_ls, + PQN_num_iters=PQN_num_iters, + ) + + #### Define hyper parameters #### + + self.PQN_min_mu = PQN_min_mu + self.max_beta = max_beta + + # Parameters of the PQN algorithm + self.PQN_c1 = PQN_c1 + self.PQN_ftol = PQN_ftol + self.PQN_num_iters_ls = PQN_num_iters_ls + self.PQN_num_iters = PQN_num_iters + + #### Define job parallelization parameters #### + self.joblib_verbosity = joblib_verbosity + self.num_jobs = num_jobs + self.joblib_backend = joblib_backend + self.irls_batch_size = irls_batch_size + + self.layers_to_save_on_disk = { + "local_adata": [ + "_mu_hat", + ], + "refit_adata": [ + None, + ], + } + + @remote + @log_remote + def compute_start_state(self, shared_states: list[dict]) -> dict: + """Compute the beta initialization, and share to the centers + + Parameters + ---------- + shared_states : list + List of shared states. + + Returns + ------- + dict + Dictionary containing the global gram matrix. + """ + + # Concatenate the local design matrices + beta_init = compute_initial_beta(shared_states) + n_non_zero_genes = beta_init.shape[0] + + # Share it with the centers + return { + "beta": beta_init, + "ascent_direction_on_mask": None, + "PQN_mask": np.ones((n_non_zero_genes,), dtype=bool), + "irls_diverged_mask": np.zeros((n_non_zero_genes,), dtype=bool), + "PQN_diverged_mask": np.zeros((n_non_zero_genes,), dtype=bool), + "global_reg_nll": np.nan * np.ones((n_non_zero_genes,)), + "round_number_PQN": 0, + "newton_decrement_on_mask": None, + } + + @remote_data + @log_remote_data + @reconstruct_adatas + def local_add_non_zero_genes( + self, data_from_opener: ad.AnnData, shared_state: dict + ) -> dict: + """Initialize the local states. + + This methods sets the local_adata and the local gram matrix. + from the reference_dds. + + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : dict + Shared state which comes from the last iteration of + the Fed Prox Quasi Newton algorithm. + Contains a `beta` and a `PQN_diverged_mask` field. + + Returns + ------- + local_state : dict + The local state containing the non zero genes mask and genes + as an addition to the input shared state. 
+ + """ + non_zero_genes_names = self.local_adata.var_names[ + self.local_adata.varm["non_zero"] + ] + non_zero_genes_mask = self.local_adata.varm["non_zero"] + return { + "PQN_diverged_mask": shared_state["PQN_diverged_mask"], + "beta": shared_state["beta"], + "non_zero_genes_mask": non_zero_genes_mask, + "non_zero_genes_names": non_zero_genes_names, + } + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a FedDESeq2 pipe. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + #### Create a reference DeseqDataset object #### + + local_states, shared_states, round_idx = self.set_local_reference_dataset( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models=clean_models, + ) + + #### Load reference dataset as local_adata and set local states #### + + # This step also shares the counts and design matrix to compute + # the beta initialization directly + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Compute the initialization shared state #### + + starting_shared_state, round_idx = aggregation_step( + aggregation_method=self.compute_start_state, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Compute the global gram matrix.", + clean_models=clean_models, + ) + + #### Set the beta init in the local states #### + + local_states, shared_states, round_idx = local_step( + local_method=self.set_beta_init, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=local_states, + input_shared_state=starting_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Set the beta init", + round_idx=round_idx, + clean_models=clean_models, + ) + + starting_shared_state, round_idx = aggregation_step( + aggregation_method=self.pass_on_shared_state, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + round_idx=round_idx, + description="Pass on the shared state", + clean_models=clean_models, + ) + + #### Perform fed PQN #### + + local_states, pqn_shared_state, round_idx = self.run_fed_PQN( + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + local_states=local_states, + PQN_shared_state=starting_shared_state, + first_iteration_mode=None, + round_idx=round_idx, + clean_models=clean_models, + ) + + #### Share the results #### + + # Local step to add non_zero genes to the shared state + local_states, shared_states, round_idx = local_step( + local_method=self.local_add_non_zero_genes, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + 
input_local_states=local_states, + input_shared_state=pqn_shared_state, + aggregation_id=aggregation_node.organization_id, + description="Add non zero genes to the shared state", + round_idx=round_idx, + clean_models=clean_models, + ) + + aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=shared_states, + description="Save the first shared state", + round_idx=round_idx, + clean_models=False, + ) diff --git a/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/pqn_test_pipe.py b/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/pqn_test_pipe.py new file mode 100644 index 0000000..6d53f5f --- /dev/null +++ b/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/pqn_test_pipe.py @@ -0,0 +1,312 @@ +"""Module to implement a testing pipeline for PQN method. + +It consists in computing the log fold changes using the PQN method +directly, and checking that the nll obtained using this method +is lower or better than the one obtained using the standard pipe. +""" + +import pickle as pkl +from pathlib import Path + +import numpy as np +from fedpydeseq2_datasets.constants import TCGADatasetNames +from fedpydeseq2_datasets.utils import get_experiment_id +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from pydeseq2.dds import DeseqDataSet +from substra import BackendType + +from fedpydeseq2.core.utils import vec_loss +from tests.tcga_testing_pipe import run_tcga_testing_pipe +from tests.unit_tests.fed_algorithms.fed_prox_quasi_newton.fed_pqn_tester import ( + FedProxQuasiNewtonTester, +) + + +def pipe_test_compute_lfc_with_pqn( + data_path: Path, + processed_data_path: Path, + tcga_assets_directory: Path, + dataset_name: TCGADatasetNames = "TCGA-LUAD", + small_samples: bool = False, + small_genes: bool = False, + simulate: bool = True, + backend: BackendType = "subprocess", + only_two_centers: bool = False, + design_factors: str | list[str] = "stage", + ref_levels: dict[str, str] | None = {"stage": "Advanced"}, # noqa: B006 + reference_dds_ref_level: tuple[str, ...] | None = None, + PQN_min_mu: float = 0.0, + rtol: float = 0.02, + atol: float = 1e-3, + nll_rtol: float = 0.02, + nll_atol: float = 1e-3, + tolerated_failed_genes: int = 5, +): + r"""Perform a unit test for the log fold change computation. + + Parameters + ---------- + data_path: Path + The path to the root data. + + processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + small_samples: bool + Whether to use a small number of samples. + If True, the number of samples is reduced to 10 per center. + + small_genes: bool + Whether to use a small number of genes. + If True, the number of genes is reduced to 100. + + simulate: bool + If true, use the simulation mode, otherwise use the subprocess mode. + + backend: BackendType + The backend to use. Either "subprocess" or "docker". + + only_two_centers: bool + If true, restrict the data to two centers. + + design_factors: str or list + The design factors to use. + + ref_levels: dict or None + The reference levels of the design factors. + + reference_dds_ref_level: tuple or None + The reference level of the design factors. + + PQN_min_mu: float + The minimum value for mu in the PQN method. 
+ + rtol: float + The relative tolerance for the LFC. + + atol: float + The absolute tolerance for the LFC. + + nll_rtol: float + The relative tolerance for the nll. + + nll_atol: float + The absolute tolerance for the nll. + + tolerated_failed_genes: int + The number of tolerated failed genes. + + """ + + # Setup the ground truth path. + experiment_id = get_experiment_id( + dataset_name, + small_samples, + small_genes, + only_two_centers=only_two_centers, + design_factors=design_factors, + continuous_factors=None, + ) + + reference_data_path = processed_data_path / "centers_data" / "tcga" / experiment_id + # Get FL results. + fl_results = run_tcga_testing_pipe( + FedProxQuasiNewtonTester( + design_factors=design_factors, + ref_levels=ref_levels, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + PQN_min_mu=PQN_min_mu, + ), + raw_data_path=data_path, + processed_data_path=processed_data_path, + assets_directory=tcga_assets_directory, + simulate=simulate, + dataset_name=dataset_name, + small_samples=small_samples, + small_genes=small_genes, + backend=backend, + only_two_centers=only_two_centers, + register_data=True, + design_factors=design_factors, + reference_dds_ref_level=reference_dds_ref_level, + ) + + # pooled dds file name + pooled_dds_file_name = get_ground_truth_dds_name(reference_dds_ref_level) + + pooled_dds_file_path = ( + processed_data_path + / "pooled_data" + / "tcga" + / experiment_id + / f"{pooled_dds_file_name}.pkl" + ) + with open(pooled_dds_file_path, "rb") as file: + pooled_dds = pkl.load(file) + + # FL gene name by convergence type + fl_beta = fl_results["beta"] + fl_PQN_diverged_mask = fl_results["PQN_diverged_mask"] + fl_non_zero_gene_names = fl_results["non_zero_genes_names"] + converged_gene_names = fl_non_zero_gene_names[~fl_PQN_diverged_mask] + diverged_gene_names = fl_non_zero_gene_names[fl_PQN_diverged_mask] + + fl_LFC_converged = fl_beta[~fl_PQN_diverged_mask, :] + fl_LFC_diverged = fl_beta[fl_PQN_diverged_mask, :] + + # pooled LFC results + pooled_LFC_converged = ( + pooled_dds.varm["LFC"].loc[converged_gene_names, :].to_numpy() + ) + pooled_LFC_diverged = pooled_dds.varm["LFC"].loc[diverged_gene_names, :].to_numpy() + + #### ---- Check for the PQN_converged ---- #### + + # For genes that have converged with the prox newton method, + # we check that the LFC are the same + # for the FL and the pooled results. + # If it is not the case, we check the relative log likelihood is not + # too different. + # If that is not the case, we check that the relative optimization error wrt + # the beta init is not too different. + # We tolerate a few failed genes. + + beta_nll_relative_error_testing( + fl_LFC_converged, + pooled_LFC_converged, + pooled_dds, + converged_gene_names, + tolerated_failed_genes=tolerated_failed_genes, + ) + + #### ---- Check for the all_diverged ---- #### + + # We perform the same checks for the genes that have not converged as well. + + beta_nll_relative_error_testing( + fl_LFC_diverged, + pooled_LFC_diverged, + pooled_dds, + diverged_gene_names, + rtol=rtol, + atol=atol, + nll_rtol=nll_rtol, + nll_atol=nll_atol, + tolerated_failed_genes=tolerated_failed_genes, + ) + + +def beta_nll_relative_error_testing( + fl_LFC: np.ndarray, + pooled_LFC: np.ndarray, + pooled_dds: DeseqDataSet, + fl_genes: list[str], + rtol: float = 0.02, + atol: float = 1e-3, + nll_rtol: float = 0.02, + nll_atol: float = 1e-3, + tolerated_failed_genes: int = 5, +): + r"""Testing for genes. 
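+
+    The FL and pooled LFC are first compared entry-wise using ``rtol`` and
+    ``atol``. For the genes exceeding these tolerances, the check falls back
+    to comparing negative binomial likelihoods: the means are rebuilt for both
+    sets of LFC as
+    ``np.maximum(size_factors[:, None] * np.exp(design @ beta.T), 0.5)``,
+    the resulting nll are compared using ``nll_rtol`` and ``nll_atol``, and at
+    most ``tolerated_failed_genes`` genes may fail this second criterion.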
+ + Parameters + ---------- + + fl_LFC: np.ndarray + The LFC from the FL results. + + pooled_LFC: np.ndarray + The LFC from the pooled results. + + pooled_dds: DeseqDataSet + The pooled DeseqDataSet. + + fl_genes: list[str] + The genes that are not IRLS converged. + + rtol: float + The relative tolerance for the LFC. + + atol: float + The absolute tolerance for the LFC. + + nll_rtol: float + The relative tolerance for the nll. + + nll_atol: float + The absolute tolerance for the nll. + + tolerated_failed_genes: int + The number of tolerated failed genes. + + """ + accepted_error = np.abs(pooled_LFC) * rtol + atol + absolute_error = np.abs(fl_LFC - pooled_LFC) + + # We check that the relative errors are not too high. + to_check_mask = (absolute_error > accepted_error).any(axis=1) + + if np.sum(to_check_mask) > 0: + # For the genes whose relative error is too high, + # We will start by checking the relative error for the nll. + to_check_genes = fl_genes[to_check_mask] + to_check_genes_index = pooled_dds.var_names.get_indexer(to_check_genes) + + counts = pooled_dds[:, to_check_genes_index].X + design = pooled_dds.obsm["design_matrix"].values + dispersions = pooled_dds[:, to_check_genes_index].varm["dispersions"] + + # We compute the mu values for the FL and the pooled results. + size_factors = pooled_dds.obsm["size_factors"] + mu_fl = np.maximum( + size_factors[:, None] * np.exp(design @ fl_LFC[to_check_mask].T), + 0.5, + ) + mu_pooled = np.maximum( + size_factors[:, None] * np.exp(design @ pooled_LFC[to_check_mask].T), + 0.5, + ) + fl_nll = vec_loss( + counts, + design, + mu_fl, + dispersions, + ) + pooled_nll = vec_loss( + counts, + design, + mu_pooled, + dispersions, + ) + + # Note: here I test the nll and not the regularized NLL which is + # the real target of the optimization. However, this should not be + # an issue since we add only a small regularization and there is + # a bound on the beta values. + accepted_nll_error = np.abs(pooled_nll) * nll_rtol + nll_atol + nll_error = fl_nll - pooled_nll + failed_test_mask = accepted_nll_error < nll_error + + # We identify the genes that do not pass the nll relative error criterion. + + if np.sum(failed_test_mask) > 0: + # We tolerate a few failed genes. + print( + "Genes that do not pass the nll relative error criterion with " + f"tolerance {nll_rtol}." + ) + print(to_check_genes[failed_test_mask]) + print("Correponding error and accepted errors: ") + print(nll_error[failed_test_mask]) + print(accepted_nll_error[failed_test_mask]) + + assert np.sum(failed_test_mask) <= tolerated_failed_genes diff --git a/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/test_fed_prox_newton_utils.py b/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/test_fed_prox_newton_utils.py new file mode 100644 index 0000000..1821709 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/test_fed_prox_newton_utils.py @@ -0,0 +1,184 @@ +import numpy as np +import pytest + +from fedpydeseq2.core.fed_algorithms.fed_PQN.utils import ( + compute_ascent_direction_decrement, +) +from fedpydeseq2.core.fed_algorithms.fed_PQN.utils import ( + compute_gradient_scaling_matrix_fisher, +) +from fedpydeseq2.core.fed_algorithms.fed_PQN.utils import ( + make_fisher_gradient_nll_step_sizes_batch, +) + + +def test_make_fisher_gradient_nll_step_sizes_batch(): + """Test the function make_fisher_gradient_nll_step_sizes_batch. + + This function runs on matrices with n_obs = 3, n_params=2, n_steps = 4, + n_genes = 5. 
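+
+    The expected output shapes, checked below, are
+    ``H``: (n_steps, n_genes, n_params, n_params),
+    ``gradient``: (n_steps, n_genes, n_params) and ``nll``: (n_steps, n_genes).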
+ """ + # Create fake data + design_matrix = np.array([[1, 2], [3, 4], [5, 6]]) + size_factors = np.array([1, 2, 3]) + beta = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]) + ascent_direction = np.array([[1, 3], [5, 7], [9, 11], [13, 15], [17, 19]]) + dispersions = np.array([1, 2, 3, 4, 5]) + counts = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]]) + step_sizes = np.array([1, 2, 3, 4]) + + min_mu = 0.5 + beta_min = -30.0 + beta_max = 30.0 + + # Call the function with the fake data + H, gradient, nll = make_fisher_gradient_nll_step_sizes_batch( + design_matrix=design_matrix, + size_factors=size_factors, + beta=beta, + dispersions=dispersions, + counts=counts, + ascent_direction=ascent_direction, + step_sizes=step_sizes, + beta_min=beta_min, + beta_max=beta_max, + min_mu=min_mu, + ) + + # Check that the outputs are correct + assert H.shape == (4, 5, 2, 2) + # Check that H is symmetric + assert np.allclose(H, H.transpose(0, 1, 3, 2)) + + assert gradient.shape == (4, 5, 2) + assert nll.shape == (4, 5) + + +def test_make_fisher_gradient_nll_step_sizes_batch_none(): + """Test the function make_fisher_gradient_nll_step_sizes_batch. + + This function runs on matrices with n_obs = 3, n_params=2, n_steps = 4, + n_genes = 5. + """ + # Create fake data + design_matrix = np.array([[1, 2], [3, 4], [5, 6]]) + size_factors = np.array([1, 2, 3]) + beta = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]) + ascent_direction = None + dispersions = np.array([1, 2, 3, 4, 5]) + counts = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]]) + step_sizes = None + + min_mu = 0.5 + beta_min = -30.0 + beta_max = 30.0 + + # Call the function with the fake data + H, gradient, nll = make_fisher_gradient_nll_step_sizes_batch( + design_matrix=design_matrix, + size_factors=size_factors, + beta=beta, + dispersions=dispersions, + counts=counts, + ascent_direction=ascent_direction, + step_sizes=step_sizes, + beta_min=beta_min, + beta_max=beta_max, + min_mu=min_mu, + ) + + # Check that the outputs are correct + assert H.shape == (1, 5, 2, 2) + # Check that H is symmetric + assert np.allclose(H, H.transpose(0, 1, 3, 2)) + + assert gradient.shape == (1, 5, 2) + assert nll.shape == (1, 5) + + +def test_make_fisher_gradient_nll_step_sizes_batch_single(): + """Test the function make_fisher_gradient_nll_step_sizes_batch. + + This test runs on matrices with only one element. 
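+
+    All output shapes then collapse to a single element:
+    ``H``: (1, 1, 1, 1), ``gradient``: (1, 1, 1) and ``nll``: (1, 1).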
+
+    """
+    # Create fake data
+    design_matrix = np.array([[1]])
+    size_factors = np.array([1])
+    beta = np.array([[1]])
+    ascent_direction = np.array([[1]])
+    step_sizes = np.array([1])
+    dispersions = np.array([1])
+    counts = np.array([[1]])
+
+    min_mu = 0.5
+    beta_min = -30.0
+    beta_max = 30.0
+
+    # Call the function with the fake data
+    H, gradient, nll = make_fisher_gradient_nll_step_sizes_batch(
+        design_matrix=design_matrix,
+        size_factors=size_factors,
+        beta=beta,
+        dispersions=dispersions,
+        counts=counts,
+        ascent_direction=ascent_direction,
+        step_sizes=step_sizes,
+        beta_min=beta_min,
+        beta_max=beta_max,
+        min_mu=min_mu,
+    )
+
+    assert H.shape == (1, 1, 1, 1)
+    assert gradient.shape == (1, 1, 1)
+    assert nll.shape == (1, 1)
+
+
+@pytest.mark.parametrize("num_jobs", [1, 2])
+def test_compute_gradient_scaling_matrix_fisher(num_jobs: int):
+    """Test the function compute_gradient_scaling_matrix_fisher."""
+    # Create the fisher matrix
+    fisher = np.array([[[2, 0], [0, 2]], [[1, 0], [0, 1]]])
+
+    # Call the function with the parametrized number of jobs
+    result = compute_gradient_scaling_matrix_fisher(
+        fisher=fisher,
+        backend="threading",
+        num_jobs=num_jobs,
+        joblib_verbosity=0,
+        batch_size=1,
+    )
+
+    # Create the expected result
+    expected_result = np.array([[[0.5, 0], [0, 0.5]], [[1, 0], [0, 1]]])
+
+    # Check that the result is correct
+    assert np.allclose(result, expected_result)
+
+
+def test_compute_ascent_direction_decrement():
+    """Test the function compute_ascent_direction_decrement."""
+    # Create the inputs
+    gradient_scaling_matrix = np.array([[[1, 0], [0, 1]], [[1, 0], [0, 1]]])
+    gradient = np.array([[2, 3], [4, 5]])
+    beta = np.array([[1, 2], [3, 4]])
+    max_beta = 5
+
+    ascent_direction, newton_decrement = compute_ascent_direction_decrement(
+        gradient_scaling_matrix=gradient_scaling_matrix,
+        gradient=gradient,
+        beta=beta,
+        max_beta=max_beta,
+    )
+
+    # Create the expected results
+    expected_ascent_direction = np.array([[2, 3], [4, 5]])
+    expected_newton_decrement = np.array([13, 41])
+
+    # Check that the results are correct
+    assert np.allclose(
+        ascent_direction, expected_ascent_direction
+    ), "The ascent direction does not match the expected ascent direction"
+    assert np.allclose(
+        newton_decrement, expected_newton_decrement
+    ), "The newton decrement does not match the expected newton decrement"
diff --git a/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/test_pqn.py b/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/test_pqn.py
new file mode 100644
index 0000000..3c72280
--- /dev/null
+++ b/tests/unit_tests/fed_algorithms/fed_prox_quasi_newton/test_pqn.py
@@ -0,0 +1,199 @@
+"""Unit test for the Fed Prox Quasi Newton algorithm."""
+
+import pytest
+from fedpydeseq2_datasets.constants import TCGADatasetNames
+
+from tests.unit_tests.fed_algorithms.fed_prox_quasi_newton.pqn_test_pipe import (
+    pipe_test_compute_lfc_with_pqn,
+)
+
+TESTING_PARAMTERS_LIST = [
+    ("TCGA-LUAD", 0.0, 5),
+    ("TCGA-PAAD", 0.0, 5),
+    ("TCGA-LUAD", 0.5, 50),
+    ("TCGA-PAAD", 0.5, 50),
+]
+
+
+@pytest.mark.usefixtures(
+    "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory"
+)
+def test_lfc_with_pqn(
+    raw_data_path,
+    tmp_processed_data_path,
+    tcga_assets_directory,
+    PQN_min_mu=0.0,
+    tolerated_failed_genes=2,
+):
+    """Perform a unit test to see if compute_lfc is working as expected.
+
+    Note that the catching of genes for which IRLS fails is very simple here,
+    as there are not enough genes to observe significant differences in the
+    log fold changes.
+ + The behaviour of the fed prox algorithm is tested on a self hosted runner. + + Moreover, we only test with the fisher scaling mode, as the other modes are + tested in the other tests, and perform less well in our tested datasets. + + We do not clip mu as this seems to yield better results. + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + PQN_min_mu: float + The minimum mu in the prox newton method. + + tolerated_failed_genes: int + The number of tolerated failed genes. + Is set to 2 by default. + + """ + + pipe_test_compute_lfc_with_pqn( + data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name="TCGA-LUAD", + small_samples=False, + small_genes=True, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + PQN_min_mu=PQN_min_mu, + tolerated_failed_genes=tolerated_failed_genes, + ) + + +@pytest.mark.self_hosted_slow +@pytest.mark.usefixtures( + "raw_data_path", "tmp_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "dataset_name, PQN_min_mu, tolerated_failed_genes", + TESTING_PARAMTERS_LIST, +) +def test_lfc_with_pqn_on_self_hosted( + raw_data_path, + tmp_processed_data_path, + tcga_assets_directory, + dataset_name: TCGADatasetNames, + PQN_min_mu: bool, + tolerated_failed_genes: int, +): + """Perform a unit test for compute_lfc using the fisher scaling mode. + + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + tmp_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + PQN_min_mu: float + The minimum mu in the prox newton method. + + tolerated_failed_genes: int + The number of tolerated failed genes. + + """ + + pipe_test_compute_lfc_with_pqn( + data_path=raw_data_path, + processed_data_path=tmp_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name=dataset_name, + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + PQN_min_mu=PQN_min_mu, + tolerated_failed_genes=tolerated_failed_genes, + ) + + +@pytest.mark.local +@pytest.mark.usefixtures( + "raw_data_path", "local_processed_data_path", "tcga_assets_directory" +) +@pytest.mark.parametrize( + "dataset_name, PQN_min_mu, tolerated_failed_genes", + TESTING_PARAMTERS_LIST, +) +def test_lfc_with_pqn_on_local( + raw_data_path, + local_processed_data_path, + tcga_assets_directory, + dataset_name: TCGADatasetNames, + PQN_min_mu: float, + tolerated_failed_genes: int, +): + """Perform a unit test for compute_lfc. + + + Parameters + ---------- + raw_data_path: Path + The path to the root data. + + local_processed_data_path: Path + The path to the processed data. The subdirectories will + be created if needed + + tcga_assets_directory: Path + The path to the assets directory. 
It must contain the + opener.py file and the description.md file. + + dataset_name: TCGADatasetNames + The name of the dataset, for example "TCGA-LUAD". + + PQN_min_mu: float + The minimum mu in the prox newton method. + + tolerated_failed_genes: int + The number of tolerated failed genes. + + """ + + pipe_test_compute_lfc_with_pqn( + data_path=raw_data_path, + processed_data_path=local_processed_data_path, + tcga_assets_directory=tcga_assets_directory, + dataset_name=dataset_name, + small_samples=False, + small_genes=False, + simulate=True, + backend="subprocess", + only_two_centers=False, + design_factors="stage", + ref_levels={"stage": "Advanced"}, + reference_dds_ref_level=None, + PQN_min_mu=PQN_min_mu, + tolerated_failed_genes=tolerated_failed_genes, + ) diff --git a/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/__init__.py b/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/opener/__init__.py b/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/opener/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/opener/description.md b/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/opener/description.md new file mode 100644 index 0000000..4831e7a --- /dev/null +++ b/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/opener/description.md @@ -0,0 +1 @@ +Test opener diff --git a/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/opener/opener.py b/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/opener/opener.py new file mode 100644 index 0000000..237a8b9 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/opener/opener.py @@ -0,0 +1,52 @@ +import pathlib + +import anndata as ad +import pandas as pd +import substratools as tools + + +class SimpleOpener(tools.Opener): + """Opener class for testing purposes. + + Creates an AnnData object from a path containing a counts_data.csv and a + metadata.csv. + """ + + def fake_data(self, n_samples=None): + """Create a fake AnnData object for testing purposes. + + Parameters + ---------- + n_samples : int + Number of samples to generate. If None, generate 100 samples. + + Returns + ------- + AnnData + An AnnData object with fake counts and metadata. + """ + pass + + def get_data(self, folders): + """get the data + + Parameters + ---------- + folders : list + List of paths to the dataset folders, whose first element should contain a + counts_data.csv and a metadata.csv file. + + Returns + ------- + AnnData + An AnnData object containing the counts and metadata loaded for the FL pipe. 
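+
+        Notes
+        -----
+        The first folder is expected to contain ``counts_data.csv``,
+        ``metadata.csv`` and a ``layer_used.txt`` file; the layer named in
+        ``layer_used.txt`` receives a copy of the loaded counts.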
+ """ + data_path = pathlib.Path(folders[0]).resolve() + counts_data = pd.read_csv(data_path / "counts_data.csv", index_col=0) + metadata = pd.read_csv(data_path / "metadata.csv", index_col=0) + with open(data_path / "layer_used.txt") as f: + layer_used = f.readline().strip() + + adata = ad.AnnData(X=counts_data, obs=metadata) + adata.layers[layer_used] = adata.X.copy() + return adata diff --git a/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/test_trimmed_mean_strategy.py b/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/test_trimmed_mean_strategy.py new file mode 100644 index 0000000..6fdc889 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/test_trimmed_mean_strategy.py @@ -0,0 +1,229 @@ +"""Implements a function running a substrafl experiment with tcga dataset.""" +import itertools +import os +import tempfile +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from pydeseq2.utils import trimmed_mean +from substra.sdk.schemas import DataSampleSpec +from substra.sdk.schemas import DatasetSpec +from substra.sdk.schemas import Permissions +from substrafl.experiment import simulate_experiment +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode + +from fedpydeseq2.substra_utils.utils import get_client +from tests.unit_tests.fed_algorithms.trimmed_mean_strategy.trimmed_mean_strategy import ( # noqa: E501 + TrimmedMeanStrategyForTesting, +) + +LIST_MOCK_DATA_TYPE = [ + "random_2", + "random_5", + "random_10", + "duplicate_5", + "one_center_tie_case_1", + "one_center_tie_case_2", + "one_center_tie_case_3", + "5_center_with_ties", +] + + +def get_data(data_type): + if data_type.startswith("random_"): + n_centers = int(data_type.split("_")[-1]) + n_genes = 100 + list_counts = [] + for _ in range(n_centers): + n_samples = np.random.randint(80, 120) + list_counts.append( + np.random.randint(0, 1000, size=(n_samples, n_genes)).astype(float) + ) + elif data_type.startswith("duplicate_"): + n_centers = int(data_type.split("_")[-1]) + n_genes = 100 + list_counts = [] + counts = np.random.randint(0, 1000, size=(100, n_genes)).astype(float) + for _ in range(n_centers): + list_counts.append(counts.copy()) + elif data_type == "one_center_tie_case_1": + n_genes = 100 + n_samples = 100 + counts = np.random.randint(0, 1000, size=(n_samples, n_genes)).astype(float) + mask = np.random.choice([True, False], size=(n_samples, n_genes)) + counts[mask] = 0.0 + list_counts = [counts] + elif data_type == "one_center_tie_case_2": + n_genes = 100 + n_samples = 100 + counts = np.random.randint(0, 1000, size=(n_samples, n_genes)).astype(float) + mask = np.random.choice([True, False], size=(n_samples, n_genes)) + counts[mask] = 1000.0 + list_counts = [counts] + elif data_type == "one_center_tie_case_3": + n_genes = 100 + n_samples = 100 + counts = np.random.randint(0, 1000, size=(n_samples, n_genes)).astype(float) + mask = np.random.choice([-1, 0, 1], size=(n_samples, n_genes)) + counts[mask == -1] = 0.0 + counts[mask == 1] = 1000.0 + list_counts = [counts] + elif data_type == "5_center_with_ties": + n_genes = 100 + n_samples = 100 + list_counts = [] + for _ in range(5): + counts = np.random.randint(0, 1000, size=(n_samples, n_genes)).astype(float) + mask = np.random.choice([-1, 0, 1], size=(n_samples, n_genes)) + counts[mask == -1] = 0.0 + counts[mask == 1] = 1000.0 + list_counts.append(counts) + else: + raise ValueError(f"Unknown data type: {data_type}") + return list_counts + + +@pytest.mark.parametrize( + 
"data_type, trim_ratio, nb_iter, layer_used, refit", + itertools.product( + LIST_MOCK_DATA_TYPE, + [0.1, 0.125, 0.2], + [40, 50], + ["layer_1", "layer_2"], + [True, False], + ), +) +def test_trimmed_mean_strategy(data_type, trim_ratio, nb_iter, layer_used, refit): + strategy = TrimmedMeanStrategyForTesting( + trim_ratio=trim_ratio, layer_used=layer_used, nb_iter=nb_iter, refit=refit + ) + list_counts = get_data(data_type) + n_centers = len(list_counts) + n_clients = n_centers + 1 + backend = "subprocess" + exp_path = Path(tempfile.mkdtemp()) + assets_directory = Path(__file__).parent / "opener" + + clients_ = [get_client(backend_type=backend) for _ in range(n_clients)] + + clients = { + client.organization_info().organization_id: client for client in clients_ + } + + # Store organization IDs + all_orgs_id = list(clients.keys()) + algo_org_id = all_orgs_id[0] # Algo provider is defined as the first organization. + data_providers_ids = all_orgs_id[ + 1: + ] # Data providers orgs are the remaining organizations. + + dataset_keys = {} + train_datasample_keys = {} + list_df_counts = [] + + for i, org_id in enumerate(data_providers_ids): + client = clients[org_id] + permissions_dataset = Permissions(public=True, authorized_ids=all_orgs_id) + dataset = DatasetSpec( + name="Test", + type="csv", + data_opener=assets_directory / "opener.py", + description=assets_directory / "description.md", + permissions=permissions_dataset, + logs_permission=permissions_dataset, + ) + print( + f"Adding dataset to client " + f"{str(client.organization_info().organization_id)}" + ) + dataset_keys[org_id] = client.add_dataset(dataset) + print("Dataset added. Key: ", dataset_keys[org_id]) + + os.makedirs(exp_path / f"dataset_{org_id}", exist_ok=True) + n_genes = list_counts[i].shape[1] + n_samples = list_counts[i].shape[0] + columns = [f"gene_{i}" for i in range(n_genes)] + index = [f"sample_{i}" for i in range(n_samples)] + # set seed + + df_counts = pd.DataFrame( + list_counts[i], + index=index, + columns=columns, + ) + df_metadata = pd.DataFrame( + np.random.randint(0, 1000, size=(n_samples, n_genes)), + index=index, + ) + list_df_counts.append(df_counts) + df_counts.to_csv(exp_path / f"dataset_{org_id}" / "counts_data.csv") + df_metadata.to_csv(exp_path / f"dataset_{org_id}" / "metadata.csv") + with open(exp_path / f"dataset_{org_id}" / "layer_used.txt", "w") as f: + f.write(layer_used) + + data_sample = DataSampleSpec( + data_manager_keys=[dataset_keys[org_id]], + path=exp_path / f"dataset_{org_id}", + ) + train_datasample_keys[org_id] = client.add_data_sample(data_sample) + + aggregation_node = AggregationNode(algo_org_id) + + train_data_nodes = [] + + for org_id in data_providers_ids: + # Create the Train Data Node (or training task) and save it in a list + train_data_node = TrainDataNode( + organization_id=org_id, + data_manager_key=dataset_keys[org_id], + data_sample_keys=[train_datasample_keys[org_id]], + ) + train_data_nodes.append(train_data_node) + + _, intermediate_train_state, intermediate_state_agg = simulate_experiment( + client=clients[algo_org_id], + strategy=strategy, + train_data_nodes=train_data_nodes, + evaluation_strategy=None, + aggregation_node=aggregation_node, + clean_models=True, + num_rounds=strategy.num_round, + experiment_folder=exp_path, + ) + + # Gather results from the aggregation node + + agg_client_id_mask = [ + w == clients[algo_org_id].organization_info().organization_id + for w in intermediate_state_agg.worker + ] + + agg_round_id_mask = [ + r == 
max(intermediate_state_agg.round_idx) + for r in intermediate_state_agg.round_idx + ] + + agg_state_idx = np.where( + [r and w for r, w in zip(agg_round_id_mask, agg_client_id_mask, strict=False)] + )[0][0] + + fl_results = intermediate_state_agg.state[agg_state_idx].results + + total_df_counts = pd.concat(list_df_counts, axis=0) + # in refit mode, we only keep the first 10 genes + if refit: + total_df_counts = total_df_counts.iloc[:, :10] + + pooled_trimmed_mean = trimmed_mean(total_df_counts, trim=trim_ratio, axis=0) + + assert np.allclose( + fl_results[f"trimmed_mean_{layer_used}"], pooled_trimmed_mean, rtol=1e-6 + ), ( + "Trimmed mean is not the same : " + + str(fl_results[f"trimmed_mean_{layer_used}"]) + + " vs " + + str(pooled_trimmed_mean) + ) diff --git a/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/trimmed_mean_strategy.py b/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/trimmed_mean_strategy.py new file mode 100644 index 0000000..f0b5710 --- /dev/null +++ b/tests/unit_tests/fed_algorithms/trimmed_mean_strategy/trimmed_mean_strategy.py @@ -0,0 +1,196 @@ +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +from substrafl import ComputePlanBuilder +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote_data + +from fedpydeseq2.core.fed_algorithms import ComputeTrimmedMean +from fedpydeseq2.core.utils import aggregation_step +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.logging import log_save_local_state +from fedpydeseq2.core.utils.pass_on_results import AggPassOnResults +from tests.unit_tests.unit_test_helpers.set_local_reference import SetLocalReference + + +class TrimmedMeanStrategyForTesting( + ComputePlanBuilder, ComputeTrimmedMean, SetLocalReference, AggPassOnResults +): + def __init__( + self, + trim_ratio: float, + layer_used: str, + nb_iter: int, + refit: bool = False, + save_layers_to_disk: bool = False, + *args, + **kwargs, + ): + self.results = {} + self.trim_ratio = trim_ratio + self.layer_used = layer_used + self.nb_iter = nb_iter + self.refit = refit + self.reference_data_path = None + self.local_adata: ad.AnnData | None = None + self.refit_adata: ad.AnnData | None = None + self.results: dict | None = None + super().__init__() + + #### Save layers to disk + self.save_layers_to_disk = save_layers_to_disk + + self.layers_to_save_on_disk = { + "local_adata": [layer_used], + "refit_adata": [layer_used], + } + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a FedDESeq2 pipe. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). 
+ """ + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states_from_opener, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=None, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + ( + local_states, + final_trimmed_mean_agg_share_state, + round_idx, + ) = self.compute_trim_mean( + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + layer_used=self.layer_used, + mode="normal", + trim_ratio=self.trim_ratio, + n_iter=self.nb_iter, + refit=self.refit, + ) + + aggregation_step( + aggregation_method=self.pass_on_results, + train_data_nodes=train_data_nodes, + aggregation_node=aggregation_node, + input_shared_states=[final_trimmed_mean_agg_share_state], + round_idx=round_idx, + clean_models=False, + ) + + @log_save_local_state + def save_local_state(self, path: Path) -> None: + """Save the local state of the strategy. + + Parameters + ---------- + path : Path + Path to the file where to save the state. Automatically handled by subtrafl. + """ + state_to_save = { + "local_adata": self.local_adata, + "refit_adata": self.refit_adata, + "results": self.results, + } + with open(path, "wb") as file: + pkl.dump(state_to_save, file) + + def load_local_state(self, path: Path) -> Any: + """Load the local state of the strategy. + + Parameters + ---------- + path : Path + Path to the file where to load the state from. Automatically handled by + subtrafl. + """ + with open(path, "rb") as file: + state_to_load = pkl.load(file) + + self.local_adata = state_to_load["local_adata"] + self.refit_adata = state_to_load["refit_adata"] + self.results = state_to_load["results"] + + return self + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states_from_opener( + self, data_from_opener: ad.AnnData, shared_state: Any + ): + """Copy the reference dds to the local state. + + If necessary, to overwrite in child classes to add relevant local states. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + shared_state : Any + Shared state. Not used. + """ + + self.local_adata = data_from_opener.copy() + # Subsample the first 10 columns to use as fake outlier genes + self.refit_adata = self.local_adata[:, :10].copy() + + @property + def num_round(self): + """Return the number of round in the strategy. + + TODO do something clever with this. + + Returns + ------- + int + Number of round in the strategy. + """ + return None + + def get_result(self): + """Return the statistic computed. + + Returns + ------- + dict + The global statistics. 
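+            For this testing strategy, the dictionary contains in particular
+            the ``"trimmed_mean_<layer_used>"`` entry produced on the
+            aggregation node.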
+ """ + return self.results diff --git a/tests/unit_tests/layers/__init__.py b/tests/unit_tests/layers/__init__.py new file mode 100644 index 0000000..4b00149 --- /dev/null +++ b/tests/unit_tests/layers/__init__.py @@ -0,0 +1 @@ +"""Module to test the reconstruct_adatas decorator.""" diff --git a/tests/unit_tests/layers/layers_tester.py b/tests/unit_tests/layers/layers_tester.py new file mode 100644 index 0000000..0d1aa0a --- /dev/null +++ b/tests/unit_tests/layers/layers_tester.py @@ -0,0 +1,528 @@ +"""Module implenting a tester class for the reconstruct_adatas decorator.""" + +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +import numpy as np +from substrafl import ComputePlanBuilder +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from tests.unit_tests.layers.utils import create_dummy_adata_with_layers + + +class SimpleLayersTester( + ComputePlanBuilder, +): + """A tester class for the reconstruct_adatas decorator. + + This class implements the following steps. + First, it creates a dummy AnnData object with the necessary layers, varm, obsm + obs to be able to reconstruct all simple layers. + + We then set this dummy AnnData object as the local_adata and local_reference_adata + attributes of the class. + + We then test the reconstruct_adatas decorator which for now only works on the + local_adata # TODO add the refit adata once it is done + by performing the following steps: + + 1. We perform an empty local step with the reconstruct_adatas decorator. + This ensures + that we have applied the decorator. + + 2. We perform a check without the decorator. In this check, we check the state + of the local_adata before the decorator is applied. + In the case where the + save_layers_to_disk attribute is set to False, we check that + the the local_adata contains no layers nor counts, except for the layers that + have been set in the layers_to_save attribute. If the cooks layer was present in + the reference adata, we check that it is still present in the local_adata, as the + decorator should not touch it. + We check that the values of the layers are the same. + + In the case where the save_layers_to_disk attribute is set to True, or that this + attribute does not exist at all we check that all layers that were present in the + reference adata are still present in the local_adata, and that the layers that were + not present are still not present. We also check that the values of the layers are + the same. + + 3. We perform a check with the decorator. In this check, we check the state of the + local_adata after the decorator is applied once more. + In the case where the save_layers_to_disk attribute is set to False, the + local adata must contain all the layers that are present in the layers_to_save + attribute if it exists, the cooks layer if it was present in the reference adata, + and all the layers that are present in the layers_to_load attribute if it exists. + If the layers_to_load attribute does not exist, we check that all layers that were + present in the reference adata are still present in the local_adata. + We also check that + the values of the layers are the same. We check that the counts are present and + that they are equal. 
+
+    In the case where the save_layers_to_disk attribute is set to True, or where
+    this attribute does not exist at all, we check that all layers that were
+    present in the reference adata are still present in the local_adata, and that
+    the layers that were not present are still not present. We also check that
+    the values of the layers are the same, and that the counts are present and
+    equal.
+
+    Parameters
+    ----------
+    num_row_values_equal_n_params : bool, optional
+        Whether the number of values taken by the design is equal to
+        the number of parameters.
+        If False, the number of row values is equal to the number of parameters + 1.
+        Defaults to False.
+
+    add_cooks : bool, optional
+        Whether to add the cooks layer. Defaults to False.
+
+    save_cooks : bool, optional
+        Whether to always keep the cooks layer saved on disk. Requires
+        ``add_cooks``. Defaults to False.
+
+    n_params : int, optional
+        Number of parameters. Defaults to 5.
+
+    has_save_layers_to_disk_attribute : bool, optional
+        Whether the class has the save_layers_to_disk attribute. Defaults to False.
+
+    save_layers_to_disk : bool, optional
+        Whether to save the layers to disk. Defaults to False.
+
+    has_layers_to_save_on_disk_attribute : bool, optional
+        Whether the class has the layers_to_save_on_disk attribute.
+
+    test_layers_to_save_on_disk_attribute : bool
+        Whether the layers_to_save_on_disk contains layers or not.
+
+    test_layers_to_save : bool, optional
+        Whether to test the layers to save. Defaults to False.
+
+    test_layers_to_load : bool, optional
+        Whether to test the layers to load. Defaults to False.
+
+    """
+
+    def __init__(
+        self,
+        num_row_values_equal_n_params=False,
+        add_cooks=False,
+        save_cooks=False,
+        n_params=5,
+        has_save_layers_to_disk_attribute: bool = False,
+        save_layers_to_disk: bool = False,
+        has_layers_to_save_on_disk_attribute: bool = False,
+        test_layers_to_save_on_disk_attribute: bool = True,
+        test_layers_to_save: bool = False,
+        test_layers_to_load: bool = False,
+        *args,
+        **kwargs,
+    ):
+        # Add all arguments to super init so that they can be retrieved by nodes.
+        super().__init__(
+            num_row_values_equal_n_params=num_row_values_equal_n_params,
+            add_cooks=add_cooks,
+            n_params=n_params,
+            has_save_layers_to_disk_attribute=has_save_layers_to_disk_attribute,
+            save_layers_to_disk=save_layers_to_disk,
+            has_layers_to_save_on_disk_attribute=has_layers_to_save_on_disk_attribute,
+            test_layers_to_save=test_layers_to_save,
+            test_layers_to_load=test_layers_to_load,
+        )
+
+        #### Quantities needed to create the adata
+        if not add_cooks:
+            assert not save_cooks, "Cannot save Cooks layer if Cooks not added."
+ + self.num_row_values_equal_n_params = num_row_values_equal_n_params + self.add_cooks = add_cooks + self.n_params = n_params + + #### Quantities needed to save and load the layers + + if has_save_layers_to_disk_attribute: + self.save_layers_to_disk = save_layers_to_disk + + self.has_layers_to_save_on_disk_attributes = ( + has_layers_to_save_on_disk_attribute + ) + self.test_layers_to_save = test_layers_to_save + self.test_layers_to_load = test_layers_to_load + self.save_cooks = save_cooks + if ( + has_layers_to_save_on_disk_attribute + and test_layers_to_save_on_disk_attribute + ): + self.layers_to_save_on_disk = { + "local_adata": ["_fit_lin_mu_hat", "cooks"] + if save_cooks + else ["_fit_lin_mu_hat"], + "refit_adata": None, + } + elif save_cooks: + self.layers_to_save_on_disk = { + "local_adata": ["cooks"], + "refit_adata": None, + } + + self.local_adata: ad.AnnData | None = None + self.refit_adata: ad.AnnData | None = None + self.local_reference_adata: ad.AnnData | None = None + + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run the test of the decorator. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + + round_idx = 0 + local_states: dict[str, LocalStateRef] = {} + + local_states, shared_states, round_idx = local_step( + local_method=self.init_local_states, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=None, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Initialize local states", + round_idx=round_idx, + clean_models=clean_models, + ) + + if self.test_layers_to_save: + layers_to_save_on_disk = { + "local_adata": ["_irls_mu_hat"], + "refit_adata": None, + } + else: + layers_to_save_on_disk = None + if self.test_layers_to_load: + layers_to_load = { + "local_adata": ["_mu_hat"] + if not self.save_cooks + else ["_mu_hat", "cooks"], + "refit_adata": None, + } + else: + layers_to_load = None + + local_states, _, round_idx = local_step( + local_method=self.empty_remote_data_method, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=None, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Empty remote data method", + round_idx=round_idx, + clean_models=clean_models, + method_params={ + "layers_to_save_on_disk": layers_to_save_on_disk, + "layers_to_load": layers_to_load, + }, + ) + + local_states, _, round_idx = local_step( + local_method=self.perform_check_without_decorator, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=None, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Check without decorator", + round_idx=round_idx, + clean_models=clean_models, + ) + + local_step( + local_method=self.perform_check_with_decorator, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=None, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Check with decorator", + 
round_idx=round_idx, + clean_models=clean_models, + method_params={ + "layers_to_save_on_disk": layers_to_save_on_disk, + "layers_to_load": layers_to_load, + }, + ) + + @remote_data + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> None: + """ + Initialize the local states of the strategy. + + This method creates a dummy AnnData object with the necessary layers, varm, obsm + obs to be able to reconstruct all simple layers. We load the + counts from the data_from_opener. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. + + shared_state : Any + Shared state. Not used. + + """ + dummy_adata = create_dummy_adata_with_layers( + data_from_opener=data_from_opener, + num_row_values_equal_n_params=self.num_row_values_equal_n_params, + add_cooks=self.add_cooks, + n_params=self.n_params, + ) + + self.local_adata = dummy_adata.copy() + self.refit_adata = None + self.local_reference_adata = dummy_adata.copy() + + @remote_data + @reconstruct_adatas + def empty_remote_data_method( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> dict: + """Empty method to test the reconstruct_adatas decorator. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Shared state. Not used. + + Returns + ------- + dict + An empty dictionary. + """ + return {} + + @remote_data + def perform_check_without_decorator( + self, data_from_opener: ad.AnnData, shared_state: Any + ): + """ + Check the state of the local adata before the decorator is applied. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + + shared_state : Any + Shared state. Not used. + + """ + + all_layers_saved = check_if_all_layers_saved(self) + if all_layers_saved: + return + + # From now on, the save_layers_to_disk attribute exists and is False + + # The layers that should be present are the layers that are either in + # the layers_to_save_on_disk attribute, or the ones that were fed + # to the method params. + + should_be_present = {"local_adata": [], "refit_adata": []} + + # We look at the layers that were defined to be saved globally. + if ( + hasattr(self, "layers_to_save_on_disk") + and self.layers_to_save_on_disk is not None + ): + for adata_name in {"local_adata", "refit_adata"}: + to_save_on_disk_global_adata = self.layers_to_save_on_disk[adata_name] + if to_save_on_disk_global_adata is not None: + should_be_present[adata_name].extend(to_save_on_disk_global_adata) + + # If we must save cooks, add them here. + elif self.save_cooks: + should_be_present["local_adata"].append("cooks") + + # Add all the layers in the layers_to_save_on_disk method parameter if it exists + if self.test_layers_to_save: + should_be_present["local_adata"].append("_irls_mu_hat") + + # Test that all layers that should be present are present, and only those. + # Check equality on the present layers. 
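+        # At this point, `should_be_present` lists, for each adata, exactly the
+        # layers expected to have survived on disk; any other layer should have
+        # been stripped by the decorator, and the surviving layers must match
+        # the reference values.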
+ + for adata_name in {"local_adata", "refit_adata"}: + adata = getattr(self, adata_name) + should_be_present_adata = should_be_present[adata_name] + if adata is None: + assert len(should_be_present_adata) == 0 + else: + assert set(adata.layers.keys()) == set(should_be_present_adata) + if adata_name == "refit_adata": + continue # TODO implement check once we really implement the test + for layer_name in list(should_be_present_adata): + assert np.allclose( + self.local_reference_adata.layers[layer_name], + adata.layers[layer_name], + equal_nan=True, + ) + # TODO do the corresponding check for the refit adata. + assert self.local_adata.X is None + + @remote_data + @reconstruct_adatas + def perform_check_with_decorator( + self, data_from_opener: ad.AnnData, shared_state: Any + ): + """ + Check the state of the local adata after the decorator is applied. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. (but used by the decorator) + shared_state : Any + Shared state. Not used. + + """ + + all_layers_saved = check_if_all_layers_saved(self) + if all_layers_saved: + return + + # Now we assumed that all layers are not saved to disk. + should_be_loaded = {"local_adata": [], "refit_adata": []} + if self.test_layers_to_load: + # In that case, we only have one layer to load + should_be_loaded["local_adata"].append("_mu_hat") + if self.save_cooks: + should_be_loaded["local_adata"].append("cooks") + + else: + # If we save cooks, then cooks should be loaded. if not, then not. + should_be_loaded["local_adata"] = list( + self.local_reference_adata.layers.keys() + ) + if not self.save_cooks: + if "cooks" in should_be_loaded["local_adata"]: + should_be_loaded["local_adata"].remove("cooks") + + for adata_name in {"local_adata", "refit_adata"}: + adata = getattr(self, adata_name) + should_be_loaded_adata = should_be_loaded[adata_name] + if adata is None: + assert len(should_be_loaded_adata) == 0 + else: + # It is only a subset because some sub layers can be created in the + # meantime. + + assert set(should_be_loaded_adata).issubset(set(adata.layers.keys())) + assert adata_name == "local_adata" # TODO check for refit + # Check the equality + for layer_name in adata.layers.keys(): + assert np.allclose( + self.local_reference_adata.layers[layer_name], + adata.layers[layer_name], + equal_nan=True, + ) + + assert np.allclose( + self.local_reference_adata.X, + self.local_adata.X, + ) + + def save_local_state(self, path: Path) -> None: + """Save the local state of the strategy. + + Parameters + ---------- + path : Path + Path to the file where to save the state. Automatically handled by subtrafl. + """ + state_to_save = { + "local_adata": self.local_adata, + "refit_adata": self.refit_adata, + "local_reference_adata": self.local_reference_adata, + } + with open(path, "wb") as file: + pkl.dump(state_to_save, file) + + def load_local_state(self, path: Path) -> Any: + """Load the local state of the strategy. + + Parameters + ---------- + path : Path + Path to the file where to load the state from. Automatically handled by + subtrafl. + """ + with open(path, "rb") as file: + state_to_load = pkl.load(file) + + self.local_adata = state_to_load["local_adata"] + self.refit_adata = state_to_load["refit_adata"] + self.local_reference_adata = state_to_load["local_reference_adata"] + return self + + @property + def num_round(self): + """Return the number of round in the strategy. + + TODO do something clever with this. 
+ + Returns + ------- + int + Number of round in the strategy. + """ + return None + + +def check_if_all_layers_saved(self): + """ + + Parameters + ---------- + self + + Returns + ------- + + """ + if not hasattr(self, "save_layers_to_disk") or self.save_layers_to_disk: + # Check that all layers originally present are still present, + assert set(self.local_reference_adata.layers.keys()) == set( + self.local_adata.layers.keys() + ) + # and that the layers that were not present are still not present + for layer_name in self.local_reference_adata.layers.keys(): + assert np.allclose( + self.local_reference_adata.layers[layer_name], + self.local_adata.layers[layer_name], + equal_nan=True, + ) + return True + return False diff --git a/tests/unit_tests/layers/opener/__init__.py b/tests/unit_tests/layers/opener/__init__.py new file mode 100644 index 0000000..a4f1b19 --- /dev/null +++ b/tests/unit_tests/layers/opener/__init__.py @@ -0,0 +1 @@ +"""An opener to test layers simplification.""" diff --git a/tests/unit_tests/layers/opener/description.md b/tests/unit_tests/layers/opener/description.md new file mode 100644 index 0000000..4831e7a --- /dev/null +++ b/tests/unit_tests/layers/opener/description.md @@ -0,0 +1 @@ +Test opener diff --git a/tests/unit_tests/layers/opener/opener.py b/tests/unit_tests/layers/opener/opener.py new file mode 100644 index 0000000..4409f53 --- /dev/null +++ b/tests/unit_tests/layers/opener/opener.py @@ -0,0 +1,47 @@ +import pathlib + +import anndata as ad +import pandas as pd +import substratools as tools + + +class SimpleOpener(tools.Opener): + """Opener class for testing purposes. + + Creates an AnnData object from a path containing a counts_data.csv. + """ + + def fake_data(self, n_samples=None): + """Create a fake AnnData object for testing purposes. + + Parameters + ---------- + n_samples : int + Number of samples to generate. If None, generate 100 samples. + + Returns + ------- + AnnData + An AnnData object with fake counts and metadata. + """ + pass + + def get_data(self, folders): + """get the data + + Parameters + ---------- + folders : list + List of paths to the dataset folders, whose first element should contain a + counts_data.csv and a metadata.csv file. + + Returns + ------- + AnnData + An AnnData object containing the counts and metadata loaded for the FL pipe. + """ + data_path = pathlib.Path(folders[0]).resolve() + counts_data = pd.read_csv(data_path / "counts_data.csv", index_col=0) + + adata = ad.AnnData(X=counts_data) + return adata diff --git a/tests/unit_tests/layers/test_reconstruct_layers.py b/tests/unit_tests/layers/test_reconstruct_layers.py new file mode 100644 index 0000000..47fd25f --- /dev/null +++ b/tests/unit_tests/layers/test_reconstruct_layers.py @@ -0,0 +1,274 @@ +"""Implement the test of simple layers.""" +import os +import tempfile +from itertools import product +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from substra.sdk.schemas import DataSampleSpec +from substra.sdk.schemas import DatasetSpec +from substra.sdk.schemas import Permissions +from substrafl.experiment import simulate_experiment +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode + +from fedpydeseq2.substra_utils.utils import get_client +from tests.unit_tests.layers.layers_tester import SimpleLayersTester + + +def get_data( + n_centers: int = 2, min_n_obs: int = 80, max_n_obs: int = 120, n_vars: int = 20 +) -> tuple[list, list]: + """Get data given hyperparameters. 
+ + For each center, generate a random number of observations + between min_n_obs and max_n_obs. + + Parameters + ---------- + n_centers : int + The number of centers. + + min_n_obs : int + The minimum number of observations. + + max_n_obs : int + The maximum number of observations. + + n_vars : int + The number of variables. Corresponds to the number of genes in + DGEA. + + Returns + ------- + list_counts : list + A list of count matrices. + + list_obs_names : list + A list of lists of observation names. + + """ + + list_counts = [] + list_obs_names = [] + n_obs_offset = 0 + for _ in range(n_centers): + n_obs = np.random.randint(min_n_obs, max_n_obs) + list_counts.append(np.random.randint(0, 1000, size=(n_obs, n_vars))) + obs_names = [f"sample_{i}" for i in range(n_obs_offset, n_obs_offset + n_obs)] + n_obs_offset += n_obs + list_obs_names.append(obs_names) + + return list_counts, list_obs_names + + +@pytest.mark.parametrize( + "num_row_values_equal_n_params, add_cooks, save_cooks, " + "has_save_layers_to_disk_attribute, save_layers_to_disk, " + "has_layers_to_save_on_disk_attribute, test_layers_to_save_on_disk_attribute, " + "test_layers_to_load, test_layers_to_save", + [ + parameters + for parameters in product( + [ + True, + False, + ], + [ + True, + False, + ], + [ + True, + False, + ], + [ + True, + False, + ], + [ + True, + False, + ], + [ + True, + False, + ], + [ + True, + False, + ], + [ + True, + False, + ], + [ + True, + False, + ], + ) + if parameters[1] or not parameters[2] + ], +) +def test_reconstruct_adatas_decorator( + num_row_values_equal_n_params: bool, + add_cooks: bool, + save_cooks: bool, + has_save_layers_to_disk_attribute: bool, + save_layers_to_disk: bool, + has_layers_to_save_on_disk_attribute: bool, + test_layers_to_save_on_disk_attribute: bool, + test_layers_to_save: bool, + test_layers_to_load: bool, + n_centers: int = 2, + n_params: bool = 5, + min_n_obs: int = 80, + max_n_obs: int = 120, + n_vars: int = 20, +): + """Test the reconstruct_adatas decorator. + + Parameters + ---------- + num_row_values_equal_n_params : bool + Whether the number of values taken by the design is equal to + the number of parameters. + If False, the number of row values is equal to the number of parameters + 1. + + add_cooks : bool + Whether to add the cooks layer. Defaults to False. + + + + has_save_layers_to_disk_attribute : bool + Whether the strategy has the save_layers_to_disk attribute. + + save_layers_to_disk : bool + Whether to save the layers to disk. + + has_layers_to_save_load_attribute : bool + Whether the strategy has the layers_to_save and layers_to_load attribute. + + test_layers_to_load : bool + Whether to test the layers to load. + + test_layers_to_save : bool + Whether to test the layers to save. + + n_centers : int + The number of centers. + + n_params : int + The number of parameters. + + min_n_obs : int + The minimum number of observations. + + max_n_obs : int + The maximum number of observations. + + n_vars : int + The number of variables. Corresponds to the number of genes in + DGEA. 
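+
+    Notes
+    -----
+    Combinations where ``save_cooks`` is True while ``add_cooks`` is False are
+    filtered out of the parameter grid, since the cooks layer cannot be saved
+    to disk if it is not added in the first place.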
+ + """ + + strategy = SimpleLayersTester( + num_row_values_equal_n_params=num_row_values_equal_n_params, + add_cooks=add_cooks, + n_params=n_params, + save_layers_to_disk=save_layers_to_disk, + has_save_layers_to_disk_attribute=has_save_layers_to_disk_attribute, + has_layers_to_save_on_disk_attribute=has_layers_to_save_on_disk_attribute, + test_layers_to_save_on_disk_attribute=test_layers_to_save_on_disk_attribute, + save_cooks=save_cooks, + test_layers_to_load=test_layers_to_load, + test_layers_to_save=test_layers_to_save, + ) + list_counts, list_obs_names = get_data(n_centers, min_n_obs, max_n_obs, n_vars) + n_centers = len(list_counts) + n_clients = n_centers + 1 + backend = "subprocess" + exp_path = Path(tempfile.mkdtemp()) + assets_directory = Path(__file__).parent / "opener" + + clients_ = [get_client(backend_type=backend) for _ in range(n_clients)] + + clients = { + client.organization_info().organization_id: client for client in clients_ + } + + # Store organization IDs + all_orgs_id = list(clients.keys()) + algo_org_id = all_orgs_id[0] # Algo provider is defined as the first organization. + data_providers_ids = all_orgs_id[ + 1: + ] # Data providers orgs are the remaining organizations. + + dataset_keys = {} + train_datasample_keys = {} + list_df_counts = [] + + for i, org_id in enumerate(data_providers_ids): + client = clients[org_id] + permissions_dataset = Permissions(public=True, authorized_ids=all_orgs_id) + dataset = DatasetSpec( + name="Test", + type="csv", + data_opener=assets_directory / "opener.py", + description=assets_directory / "description.md", + permissions=permissions_dataset, + logs_permission=permissions_dataset, + ) + print( + f"Adding dataset to client " + f"{str(client.organization_info().organization_id)}" + ) + dataset_keys[org_id] = client.add_dataset(dataset) + print("Dataset added. 
Key: ", dataset_keys[org_id]) + + os.makedirs(exp_path / f"dataset_{org_id}", exist_ok=True) + n_genes = list_counts[i].shape[1] + columns = [f"gene_{i}" for i in range(n_genes)] + + # set seed + + df_counts = pd.DataFrame( + list_counts[i], + index=list_obs_names[i], + columns=columns, + ) + list_df_counts.append(df_counts) + df_counts.to_csv(exp_path / f"dataset_{org_id}" / "counts_data.csv") + + data_sample = DataSampleSpec( + data_manager_keys=[dataset_keys[org_id]], + path=exp_path / f"dataset_{org_id}", + ) + train_datasample_keys[org_id] = client.add_data_sample(data_sample) + + aggregation_node = AggregationNode(algo_org_id) + + train_data_nodes = [] + + for org_id in data_providers_ids: + # Create the Train Data Node (or training task) and save it in a list + train_data_node = TrainDataNode( + organization_id=org_id, + data_manager_key=dataset_keys[org_id], + data_sample_keys=[train_datasample_keys[org_id]], + ) + train_data_nodes.append(train_data_node) + + simulate_experiment( + client=clients[algo_org_id], + strategy=strategy, + train_data_nodes=train_data_nodes, + evaluation_strategy=None, + aggregation_node=aggregation_node, + clean_models=True, + num_rounds=strategy.num_round, + experiment_folder=exp_path, + ) diff --git a/tests/unit_tests/layers/test_utils.py b/tests/unit_tests/layers/test_utils.py new file mode 100644 index 0000000..d135b10 --- /dev/null +++ b/tests/unit_tests/layers/test_utils.py @@ -0,0 +1,135 @@ +import numpy as np + +from fedpydeseq2.core.utils.layers.build_layers.hat_diagonals import make_hat_diag_batch +from fedpydeseq2.core.utils.layers.build_layers.mu_layer import make_mu_batch +from fedpydeseq2.core.utils.layers.cooks_layer import make_hat_matrix_summands_batch + + +def test_make_hat_matrix_summands_batch(): + """Test the function make_hat_matrix_summands_batch. + + This test checks that the function returns the correct output shape given + input shapes + of size (3, 2), (3,), (5, 2), (5,), and a scalar. + """ + # Create fake data + design_matrix = np.array([[1, 2], [3, 4], [5, 6]]) + size_factors = np.array([1, 2, 3]) + beta = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]) + dispersions = np.array([1, 2, 3, 4, 5]) + min_mu = 0.1 + + H = make_hat_matrix_summands_batch( + design_matrix, size_factors, beta, dispersions, min_mu + ) + + # Check that the outputs are correct + + assert np.allclose(H, H.transpose(0, 2, 1)) + + assert H.shape == (5, 2, 2) + + +def test_make_hat_matrix_summands_batch_single_dim(): + """Test the function make_hat_matrix_summands_batch. + + This test checks the border case where the design matrix has only one row, and + there is only one gene. + """ + design_matrix = np.array([[1]]) + size_factors = np.array([1]) + beta = np.array([[1]]) + dispersions = np.array([1]) + min_mu = 0.1 + + H = make_hat_matrix_summands_batch( + design_matrix, size_factors, beta, dispersions, min_mu + ) + + assert H.shape == (1, 1, 1) + + +def test_make_mu_batch(): + """Test the function make_mu_batch. + + This test checks that the function returns the correct output shapes + given input shapes of size (5, 2), (3, 2), and (3,). + """ + beta = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]) + design_matrix = np.array([[1, 2], [3, 4], [5, 6]]) + size_factors = np.array([1, 2, 3]) + + mu = make_mu_batch( + beta, + design_matrix, + size_factors, + ) + + assert mu.shape == (3, 5) + + +def test_make_mu_batch_single_dim(): + """Test the function make_irls_mu_and_diag_batch. 
+ + This test checks the border case where the design matrix has only one row, and + there is only one gene. + """ + beta = np.array([[1]]) + design_matrix = np.array([[1]]) + size_factors = np.array([1]) + + mu = make_mu_batch( + beta, + design_matrix, + size_factors, + ) + + assert mu.shape == (1, 1) + + +def test_make_hat_diag_batch(): + """Test the function make_hat_diag_batch. + + This test checks that the function returns the correct output shapes + given input shapes of size (3, 2), (3,), (5, 2), (5,), and a scalar. + """ + beta = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]) + global_hat_matrix_inv = np.array( + [ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + [[9, 10], [11, 12]], + [[13, 14], [15, 16]], + [[17, 18], [19, 20]], + ] + ) + design_matrix = np.array([[1, 2], [3, 4], [5, 6]]) + size_factors = np.array([1, 2, 3]) + dispersions = np.array([1, 2, 3, 4, 5]) + min_mu = 0.1 + + H = make_hat_diag_batch( + beta, global_hat_matrix_inv, design_matrix, size_factors, dispersions, min_mu + ) + + assert H.shape == (5, 3) + + +def test_make_hat_diag_batch_single_dim(): + """Test the function make_hat_diag_batch. + + This test checks the border case where the design matrix has only one row, and + there is only one gene. + """ + beta = np.array([[1]]) + global_hat_matrix_inv = np.array([[[1]]]) + design_matrix = np.array([[1]]) + size_factors = np.array([1]) + dispersions = np.array([1]) + min_mu = 0.1 + + H = make_hat_diag_batch( + beta, global_hat_matrix_inv, design_matrix, size_factors, dispersions, min_mu + ) + + assert H.shape == (1, 1) diff --git a/tests/unit_tests/layers/utils.py b/tests/unit_tests/layers/utils.py new file mode 100644 index 0000000..0e034db --- /dev/null +++ b/tests/unit_tests/layers/utils.py @@ -0,0 +1,112 @@ +"""Module to test create dummy AnnData with layers.""" + +import anndata as ad +import numpy as np +import pandas as pd + +from fedpydeseq2.core.utils.layers import set_mu_layer +from fedpydeseq2.core.utils.layers.build_layers import set_fit_lin_mu_hat +from fedpydeseq2.core.utils.layers.build_layers import set_mu_hat_layer +from fedpydeseq2.core.utils.layers.build_layers import set_normed_counts +from fedpydeseq2.core.utils.layers.build_layers import set_sqerror_layer +from fedpydeseq2.core.utils.layers.build_layers import set_y_hat + + +def create_dummy_adata_with_layers( + data_from_opener: ad.AnnData, + num_row_values_equal_n_params=False, + add_cooks=False, + n_params=5, +) -> ad.AnnData: + """Create a dummy AnnData object with the necessary layers for testing. + + Parameters + ---------- + data_from_opener : ad.AnnData + The data from the opener. + + num_row_values_equal_n_params : bool, optional + Whether the number of values taken by the design is equal to + the number of parameters. + If False, the number of row values is equal to the number of parameters + 1. + Defaults to False. + + add_cooks : bool, optional + Whether to add the cooks layer. Defaults to False. + + n_params : int, optional + Number of parameters. Defaults to 5. + + Returns + ------- + ad.AnnData + The dummy AnnData object. 
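+
+    Examples
+    --------
+    A minimal usage sketch (illustration only, not meant to be run as a
+    doctest), where a small random AnnData stands in for the opener output::
+
+        rng = np.random.default_rng(0)
+        data_from_opener = ad.AnnData(
+            X=rng.poisson(10.0, size=(20, 8)).astype(float)
+        )
+        adata = create_dummy_adata_with_layers(data_from_opener, n_params=2)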
+ """ + n_obs, n_vars = data_from_opener.X.shape + + adata = ad.AnnData( + X=data_from_opener.X, + obs=data_from_opener.obs, + var=data_from_opener.var, + ) + + # We need to have a "cells" obs field + adata.obs["cells"] = np.random.choice(["A", "B", "C", "D", "E"], size=n_obs) + + # We need to create the following obsm fields + # - design_matrix + # - size_factors + adata.obsm["design_matrix"] = pd.DataFrame( + index=adata.obs_names, + data=np.random.randint(low=0, high=2, size=(n_obs, n_params)), + ) + adata.obsm["size_factors"] = np.random.rand(n_obs) + + # We need to have the following uns fields + # - n_params + adata.uns["n_params"] = n_params + # - num_replicates + adata.uns["num_replicates"] = pd.DataFrame( + index=np.arange(n_params if num_row_values_equal_n_params else n_params + 1), + data=np.random.rand( + n_params if num_row_values_equal_n_params else n_params + 1 + ), + ) + + # We need to create the following varm fields + # - _beta_rough_dispersions + adata.varm["_beta_rough_dispersions"] = np.random.rand(n_vars, n_params) + adata.varm["non_zero"] = np.random.rand(n_vars) > 0.2 + adata.varm["_mu_hat_LFC"] = pd.DataFrame( + index=adata.var_names, data=np.random.rand(n_vars, n_params) + ) + adata.varm["LFC"] = pd.DataFrame( + index=adata.var_names, data=np.random.rand(n_vars, n_params) + ) + adata.varm["cell_means"] = pd.DataFrame( + index=adata.var_names, + columns=["A", "B", "C", "D", "E"], + data=np.random.rand(n_vars, 5), + ) + + set_normed_counts(adata) + set_mu_layer( + local_adata=adata, + lfc_param_name="LFC", + mu_param_name="_mu_LFC", + ) + set_mu_layer( + local_adata=adata, + lfc_param_name="_mu_hat_LFC", + mu_param_name="_irls_mu_hat", + ) + + set_sqerror_layer(adata) + set_y_hat(adata) + set_fit_lin_mu_hat(adata) + set_mu_hat_layer(adata) + + if add_cooks: + adata.layers["cooks"] = np.random.rand(n_obs, n_vars) + + return adata diff --git a/tests/unit_tests/unit_test_helpers/__init__.py b/tests/unit_tests/unit_test_helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/unit_test_helpers/levels.py b/tests/unit_tests/unit_test_helpers/levels.py new file mode 100644 index 0000000..841efef --- /dev/null +++ b/tests/unit_tests/unit_test_helpers/levels.py @@ -0,0 +1,72 @@ +DEFAULT_PYDESEQ2_REF_LEVELS = { + "stage": "Advanced", + "gender": "female", +} + + +def make_reference_and_fl_ref_levels( + design_factors: list[str], + continuous_factors: list[str] | None = None, + ref_levels: dict[str, str] | None = None, + reference_dds_ref_level: tuple[str, str] | None = None, +) -> tuple[dict[str, str] | None, tuple[str, str] | None]: + """Function to make reference levels for PyDESeq2 and FedPyDESeq2. + + The goal of this function is to enforce that the design matrices will be + comparable between PyDESeq2 and FedPyDESeq2. This is done by ensuring that + the reference levels are the same for both packages. + + Parameters + ---------- + design_factors : list[str] + List of factors in the design matrix. + continuous_factors : list[str] or None + List of continuous factors in the design matrix, by default None. + ref_levels : dict[str, str] or None + Reference levels for the factors in the design matrix, by default None. + These are the reference levels used as arguments in the + FedPyDESeq2Strategy class. + reference_dds_ref_level : tuple[str, str] or None + Reference level for the factor in the design matrix, by default None. + This is the reference level used as an argument in the DESeqDataSet + class from the pydeseq2 package. 
+ + + Returns + ------- + complete_ref_levels : Optional[dict[str, str]] + Reference levels for the factors in the design matrix for the + FedPyDESeq2 package. + reference_dds_ref_level : Optional[tuple[str, str]] + Reference level for the factor in the design matrix for the + PyDESeq2 package. + + """ + categorical_factors = ( + design_factors + if continuous_factors is None + else [factor for factor in design_factors if factor not in continuous_factors] + ) + complete_ref_levels = { + factor: level + for factor, level in DEFAULT_PYDESEQ2_REF_LEVELS.items() + if factor in categorical_factors + } + if ref_levels is not None: + if len(ref_levels) > 1: + print( + "Warning: only one reference level is supported when comparing with " + "PyDESeq2. The first reference level will be used." + ) + ref_factor, ref_level = next(iter(ref_levels.items())) + if reference_dds_ref_level is not None: + assert ref_factor == reference_dds_ref_level[0] + assert ref_level == reference_dds_ref_level[1] + else: + reference_dds_ref_level = (ref_factor, ref_level) + complete_ref_levels[ref_factor] = ref_level + elif reference_dds_ref_level is not None: + ref_factor, ref_level = reference_dds_ref_level + complete_ref_levels[ref_factor] = ref_level + + return complete_ref_levels, reference_dds_ref_level diff --git a/tests/unit_tests/unit_test_helpers/pass_on_first_shared_state.py b/tests/unit_tests/unit_test_helpers/pass_on_first_shared_state.py new file mode 100644 index 0000000..cbdfa7b --- /dev/null +++ b/tests/unit_tests/unit_test_helpers/pass_on_first_shared_state.py @@ -0,0 +1,31 @@ +"""Module to implement the passing of the first shared state. + +# TODO remove after all savings have been factored out, if not needed anymore. +""" +from substrafl.remote import remote + +from fedpydeseq2.core.utils.logging import log_remote + + +class AggPassOnFirstSharedState: + """Mixin to pass on the first shared state.""" + + @remote + @log_remote + def pass_on_shared_state(self, shared_states: list[dict]) -> dict: + """Pass on the shared state. + + This method simply returns the first shared state. + + Parameters + ---------- + shared_states : list + List of shared states. + + Returns + ------- + shared_state : dict + The shared state to be passed on. + + """ + return shared_states[0] diff --git a/tests/unit_tests/unit_test_helpers/set_local_reference.py b/tests/unit_tests/unit_test_helpers/set_local_reference.py new file mode 100644 index 0000000..d495e7f --- /dev/null +++ b/tests/unit_tests/unit_test_helpers/set_local_reference.py @@ -0,0 +1,112 @@ +"""Module to implement the step to set the local reference for the DESeq2Strategy.""" +import pickle as pkl +from pathlib import Path +from typing import Any + +import anndata as ad +from pydeseq2.dds import DeseqDataSet +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils import local_step +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data + + +class SetLocalReference: + """Mixin to set the local reference for the DESeq2Strategy.""" + + reference_data_path: str | Path | None + local_reference_dds: DeseqDataSet | None + reference_dds_name: str + + def set_local_reference_dataset( + self, + train_data_nodes, + aggregation_node, + local_states, + round_idx, + clean_models, + ): + """Set the local reference DeseqDataset. 
+ + This function restricts the global reference DeseqDataset (assumed to be + constructed from the + pooled data and save in the right subdirectory, see + `fedpydeseq2.data.tcga_setup` module + for more details) to the data in the local dataset. + + It then sets the local reference DeseqDataset as an attribute of the strategy. + + Parameters + ---------- + train_data_nodes : list + List of TrainDataNode. + aggregation_node : AggregationNode + Aggregation Node. + local_states : dict + Local states. Required to propagate intermediate results. + round_idx : int + Index of the current round. + clean_models : bool + Whether to clean the models after the computation. + + Returns + ------- + local_states : dict + Local states. Required to propagate intermediate results. + shared_states : Any + Empty shared state. + round_idx : int + Index of the current round. + """ + local_states, shared_states, round_idx = local_step( + local_method=self.set_local_reference_dataset_remote, + train_data_nodes=train_data_nodes, + output_local_states=local_states, + input_local_states=None, + input_shared_state=None, + aggregation_id=aggregation_node.organization_id, + description="Setting the reference local dds object, " + "if given a reference_data_path.", + round_idx=round_idx, + clean_models=clean_models, + ) + + return local_states, shared_states, round_idx + + @remote_data + @log_remote_data + @reconstruct_adatas + def set_local_reference_dataset_remote( + self, data_from_opener: ad.AnnData, shared_state: Any + ): + """Set the local reference DeseqDataset. + + This function restricts the global reference DeseqDataset (assumed to be + constructed from the + pooled data and save in the right subdirectory, see + `fedpydeseq2.data.tcga_setup` module + for more details) to the data in the local dataset. + + It then sets the local reference DeseqDataset as an attribute of the strategy. + + Parameters + ---------- + data_from_opener : ad.AnnData + The local AnnData. Used here to access the center_id. + shared_state : Any + Not used here. + + """ + if self.reference_data_path is not None: + reference_data_path = Path(self.reference_data_path).resolve() + center_id = data_from_opener.uns["center_id"] + # Get the + path = ( + reference_data_path + / f"center_{center_id}" + / f"{self.reference_dds_name}.pkl" + ) + with open(path, "rb") as file: + local_reference_dds = pkl.load(file) + self.local_reference_dds = local_reference_dds diff --git a/tests/unit_tests/unit_test_helpers/unit_tester.py b/tests/unit_tests/unit_test_helpers/unit_tester.py new file mode 100644 index 0000000..11e1804 --- /dev/null +++ b/tests/unit_tests/unit_test_helpers/unit_tester.py @@ -0,0 +1,258 @@ +import pickle as pkl +from abc import abstractmethod +from pathlib import Path +from typing import Any + +import anndata as ad +from fedpydeseq2_datasets.utils import get_ground_truth_dds_name +from pydeseq2.dds import DeseqDataSet +from substrafl import ComputePlanBuilder +from substrafl.nodes import AggregationNode +from substrafl.nodes import TrainDataNode +from substrafl.remote import remote_data + +from fedpydeseq2.core.utils.layers import reconstruct_adatas +from fedpydeseq2.core.utils.logging import log_remote_data +from fedpydeseq2.core.utils.logging import log_save_local_state +from tests.unit_tests.unit_test_helpers.set_local_reference import SetLocalReference + + +class UnitTester( + ComputePlanBuilder, + SetLocalReference, +): + """A base semi-abstract class to implement unit tests for DESea2 steps. 
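+
+    Concrete testers subclass this class and implement ``build_compute_plan``
+    to chain the steps under test. The base class provides the shared
+    plumbing: ``init_local_states`` copies the local reference
+    ``DeseqDataSet`` into ``local_adata``, while ``save_local_state`` and
+    ``load_local_state`` persist and restore ``results``, ``local_adata``,
+    ``refit_adata`` and ``local_reference_dds`` between substrafl tasks.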
+ + Parameters + ---------- + design_factors : str or list + Name of the columns of metadata to be used as design variables. + If you are using categorical and continuous factors, you must put + all of them here. + + ref_levels : dict or None + An optional dictionary of the form ``{"factor": "test_level"}`` + specifying for each factor the reference (control) level against which + we're testing, e.g. ``{"condition", "A"}``. Factors that are left out + will be assigned random reference levels. (default: ``None``). + + continuous_factors : list or None + An optional list of continuous (as opposed to categorical) factors. Any factor + not in ``continuous_factors`` will be considered categorical + (default: ``None``). + + contrast : list or None + A list of three strings, in the following format: + ``['variable_of_interest', 'tested_level', 'ref_level']``. + Names must correspond to the metadata data passed to the DeseqDataSet. + E.g., ``['condition', 'B', 'A']`` will measure the LFC of 'condition B' compared + to 'condition A'. + For continuous variables, the last two strings should be left empty, e.g. + ``['measurement', '', ''].`` + If None, the last variable from the design matrix is chosen + as the variable of interest, and the reference level is picked alphabetically. + (default: ``None``). + + reference_data_path : str or Path + The path to the reference data. This is used to build the reference + DeseqDataSet. This is only used for testing purposes, and should not be + used in a real-world scenario. + + reference_dds_ref_level : tuple por None + The reference level of the reference DeseqDataSet. This is used to build the + reference DeseqDataSet. This is only used for testing purposes, and should not + be used in a real-world scenario. + + refit_cooks: bool + Whether to refit the model with the Cook's distance. (default: ``False``). + + joblib_backend : str + The joblib backend to use for parallelization. (default: ``"loky"``). + + save_layers_to_disk : bool + Whether to save the layers to disk. (default: ``False``). + + + The log level to use for the substrafl logger. (default: ``logging.DEBUG``). + + """ + + def __init__( + self, + design_factors: str | list[str], + ref_levels: dict[str, str] | None = None, + continuous_factors: list[str] | None = None, + contrast: list[str] | None = None, + reference_data_path: str | Path | None = None, + reference_dds_ref_level: tuple[str, ...] | None = None, + refit_cooks: bool = False, + joblib_backend: str = "loky", + save_layers_to_disk: bool = False, + *args, + **kwargs, + ): + # Add all arguments to super init so that they can be retrieved by nodes. + super().__init__( + design_factors=design_factors, + ref_levels=ref_levels, + continuous_factors=continuous_factors, + contrast=contrast, + reference_data_path=reference_data_path, + reference_dds_ref_level=reference_dds_ref_level, + refit_cooks=refit_cooks, + joblib_backend=joblib_backend, + save_layers_to_disk=save_layers_to_disk, + ) + + #### Define quantities to set the design #### + + # Convert design_factors to list if a single string was provided. 
+ self.design_factors = ( + [design_factors] if isinstance(design_factors, str) else design_factors + ) + + self.ref_levels = ref_levels + self.continuous_factors = continuous_factors + + if self.continuous_factors is not None: + self.categorical_factors = [ + factor + for factor in self.design_factors + if factor not in self.continuous_factors + ] + else: + self.categorical_factors = self.design_factors + + self.contrast = contrast + + #### Set attributes to be registered / saved later on #### + self.results: dict = {} + self.local_adata: ad.AnnData = None + self.refit_adata: ad.AnnData = None + + #### Used only if we want the reference + self.reference_data_path = reference_data_path + self.reference_dds_name = get_ground_truth_dds_name( + reference_dds_ref_level, refit_cooks=refit_cooks + ) + self.local_reference_dds: DeseqDataSet | None = None + + #### Joblib parameters #### + self.joblib_backend = joblib_backend + + #### Layers paramters + self.layers_to_load: dict[str, list[str] | None] | None = None + self.layers_to_save_on_disk: dict[str, list[str] | None] | None = None + + #### Save layers to disk + self.save_layers_to_disk = save_layers_to_disk + + @abstractmethod + def build_compute_plan( + self, + train_data_nodes: list[TrainDataNode], + aggregation_node: AggregationNode, + evaluation_strategy=None, + num_rounds=None, + clean_models=True, + ): + """Build the computation graph to run a FedDESeq2 pipe. + + Parameters + ---------- + train_data_nodes : list[TrainDataNode] + List of the train nodes. + aggregation_node : AggregationNode + Aggregation node. + evaluation_strategy : EvaluationStrategy + Not used. + num_rounds : int + Number of rounds. Not used. + clean_models : bool + Whether to clean the models after the computation. (default: ``False``). + """ + + @remote_data + @log_remote_data + @reconstruct_adatas + def init_local_states( + self, data_from_opener: ad.AnnData, shared_state: Any + ) -> None: + """Copy the reference dds to the local state. + + If necessary, to overwrite in child classes to add relevant local states. + + Parameters + ---------- + data_from_opener : ad.AnnData + AnnData returned by the opener. Not used. + shared_state : Any + Shared state. Not used. + """ + + self.local_adata = self.local_reference_dds.copy() + # Delete the unused "_mu_hat" layer as it will raise errors + del self.local_adata.layers["_mu_hat"] + # This field is not saved in pydeseq2 but used in fedpyseq2 + self.local_adata.uns["n_params"] = self.local_adata.obsm["design_matrix"].shape[ + 1 + ] + + @log_save_local_state + def save_local_state(self, path: Path) -> None: + """Save the local state of the strategy. + + Parameters + ---------- + path : Path + Path to the file where to save the state. Automatically handled by subtrafl. + """ + state_to_save = { + "results": self.results, + "local_adata": self.local_adata, + "refit_adata": self.refit_adata, + "local_reference_dds": self.local_reference_dds, + } + with open(path, "wb") as file: + pkl.dump(state_to_save, file) + + def load_local_state(self, path: Path) -> Any: + """Load the local state of the strategy. + + Parameters + ---------- + path : Path + Path to the file where to load the state from. Automatically handled by + subtrafl. 
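+
+        Returns
+        -------
+        UnitTester
+            The strategy itself, with ``results``, ``local_adata``,
+            ``refit_adata`` and ``local_reference_dds`` restored.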
+        """
+        with open(path, "rb") as file:
+            state_to_load = pkl.load(file)
+
+        self.results = state_to_load["results"]
+        self.local_adata = state_to_load["local_adata"]
+        self.refit_adata = state_to_load["refit_adata"]
+        self.local_reference_dds = state_to_load["local_reference_dds"]
+        return self
+
+    @property
+    def num_round(self):
+        """Return the number of rounds in the strategy.
+
+        TODO do something clever with this.
+
+        Returns
+        -------
+        int
+            Number of rounds in the strategy.
+        """
+        return None
+
+    def get_result(self):
+        """Return the computed statistics.
+
+        Returns
+        -------
+        dict
+            The global statistics.
+        """
+        return self.results
diff --git a/tests/unit_tests/utils/test_mle.py b/tests/unit_tests/utils/test_mle.py
new file mode 100644
index 0000000..55d34ad
--- /dev/null
+++ b/tests/unit_tests/utils/test_mle.py
@@ -0,0 +1,30 @@
+from itertools import product
+
+import numpy as np
+import pytest
+
+from fedpydeseq2.core.utils.mle import global_grid_cr_loss
+
+
+@pytest.mark.parametrize(
+    "n_genes, grid_length, n_params, percentage_nan",
+    product([50, 100, 500], [5, 20, 50], [2, 3, 5], [0.1, 0.2, 0.9, 1.0]),
+)
+def test_global_grid_cr_loss_with_nans(n_genes, grid_length, n_params, percentage_nan):
+    """Test the global_grid_cr_loss function with NaNs in the input arrays."""
+    np.random.seed(seed=42)
+
+    nll = np.random.uniform(size=(n_genes, grid_length))
+    mask_nan = np.random.uniform(size=(n_genes, grid_length)) < percentage_nan
+    nll[mask_nan] = np.nan
+    cr_grid = np.random.uniform(size=(n_genes, grid_length, n_params, n_params))
+    mask_nan_cr_grid = np.random.uniform(size=(n_genes, grid_length)) < percentage_nan
+    cr_grid[mask_nan_cr_grid] = np.nan
+
+    expected = nll + 0.5 * np.linalg.slogdet(cr_grid)[1]
+    result = global_grid_cr_loss(nll, cr_grid)
+    assert np.array_equal(result, expected, equal_nan=True)
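+
+
+# Illustrative companion sketch: it assumes that, in the absence of NaNs,
+# ``global_grid_cr_loss`` reduces to the plain Cox-Reid adjusted NLL,
+# ``nll + 0.5 * log|det(cr_grid)|``, i.e. the same identity asserted above.
+def test_global_grid_cr_loss_without_nans():
+    """Sanity-check the CR-adjusted loss on a small grid without NaNs."""
+    np.random.seed(seed=0)
+    nll = np.random.uniform(size=(4, 3))
+    cr_grid = np.random.uniform(size=(4, 3, 2, 2))
+    # Symmetrise and shift the 2x2 blocks so that they are diagonally
+    # dominant, hence positive definite with a well-defined log-determinant.
+    cr_grid = cr_grid + cr_grid.transpose(0, 1, 3, 2) + 2.0 * np.eye(2)
+    expected = nll + 0.5 * np.linalg.slogdet(cr_grid)[1]
+    assert np.allclose(global_grid_cr_loss(nll, cr_grid), expected)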