From e693ce6974c90e5704939a90a53c9789ecd26591 Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Thu, 30 May 2024 18:06:44 +0200 Subject: [PATCH] parallelized the computation of weighted_* quantities with MPI; updated the relevant tests; updated pyproject.toml and setup.py to include the path for mpi4py headers automatically; github test workflow will now run on every push --- .github/workflows/tests.yaml | 6 +- brahmap/__init__.py | 14 +- brahmap/_extensions/compute_weights.cpp | 490 ++++++++++++---------- brahmap/_extensions/mpi_utils.hpp | 75 ++++ brahmap/utilities/__init__.py | 3 +- brahmap/utilities/mpi.py | 16 +- brahmap/utilities/process_time_samples.py | 100 +++-- pyproject.toml | 8 +- setup.py | 3 +- tests/helper_ComputeWeights.py | 18 + tests/helper_ProcessTimeSamples.py | 11 + tests/test_BlkDiagPrecondLO.py | 17 +- tests/test_BlkDiagPrecondLO_tools_cpp.py | 11 +- tests/test_PointingLO.py | 7 +- tests/test_ProcessTimeSamples.py | 35 +- tests/test_compute_weights_cpp.py | 19 +- tests/test_repixelization_cpp.py | 13 +- 17 files changed, 553 insertions(+), 293 deletions(-) create mode 100644 brahmap/_extensions/mpi_utils.hpp diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index c17299d..749a3aa 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -4,10 +4,8 @@ # setting up compiler: name: Tests -on: - push: - branches: - - main +on: [push] + jobs: build: runs-on: ${{ matrix.os }} diff --git a/brahmap/__init__.py b/brahmap/__init__.py index 31b116b..12e048b 100644 --- a/brahmap/__init__.py +++ b/brahmap/__init__.py @@ -1,6 +1,15 @@ -from . import interfaces, utilities, linop, mapmakers, _extensions +import mpi4py -from .utilities import Initialize +mpi4py.rc.initialize = False + +from mpi4py import MPI # noqa: E402 + +if MPI.Is_initialized() is False: + MPI.Init_thread(required=MPI.THREAD_FUNNELED) + +from . import interfaces, utilities, linop, mapmakers, _extensions # noqa: E402 + +from .utilities import Initialize, MPI_RAISE_EXCEPTION # noqa: E402 bMPI = None @@ -11,4 +20,5 @@ "mapmakers", "_extensions", "Initialize", + "MPI_RAISE_EXCEPTION", ] diff --git a/brahmap/_extensions/compute_weights.cpp b/brahmap/_extensions/compute_weights.cpp index a0fa0c5..54a4396 100644 --- a/brahmap/_extensions/compute_weights.cpp +++ b/brahmap/_extensions/compute_weights.cpp @@ -3,6 +3,10 @@ #include #include +#include + +#include "mpi_utils.hpp" + namespace py = pybind11; template @@ -15,7 +19,8 @@ dint compute_weights_pol_I( // dfloat *weighted_counts, // dint *observed_pixels, // dint *__old2new_pixel, // - bool *pixel_flag // + bool *pixel_flag, // + const MPI_Comm comm // ) { for (ssize_t idx = 0; idx < nsamples; ++idx) { @@ -25,6 +30,9 @@ dint compute_weights_pol_I( // weighted_counts[pixel] += weight; } // for + MPI_Allreduce(MPI_IN_PLACE, weighted_counts, npix, mpi_get_type(), + MPI_SUM, comm); + dint new_npix = 0; for (ssize_t idx = 0; idx < npix; ++idx) { if (weighted_counts[idx] > 0) { @@ -40,21 +48,21 @@ dint compute_weights_pol_I( // } // compute_weights_pol_I() template -void compute_weights_pol_QU( // - const ssize_t npix, // - const ssize_t nsamples, // - const dint *pointings, // - const bool *pointings_flag, // - const dfloat *noise_weights, // - const dfloat *pol_angles, // - dfloat *weighted_counts, // - dfloat *sin2phi, // - dfloat *cos2phi, // - dfloat *weighted_sin_sq, // - dfloat *weighted_cos_sq, // - dfloat *weighted_sincos, // - dfloat *one_over_determinant // - +void compute_weights_pol_QU( // + const ssize_t npix, // + const ssize_t nsamples, // + const dint *pointings, // + const bool *pointings_flag, // + const dfloat *noise_weights, // + const dfloat *pol_angles, // + dfloat *weighted_counts, // + dfloat *sin2phi, // + dfloat *cos2phi, // + dfloat *weighted_sin_sq, // + dfloat *weighted_cos_sq, // + dfloat *weighted_sincos, // + dfloat *one_over_determinant, // + const MPI_Comm comm // ) { for (ssize_t idx = 0; idx < nsamples; ++idx) { @@ -73,6 +81,15 @@ void compute_weights_pol_QU( // weighted_sincos[pixel] += weight * sin2phi[idx] * cos2phi[idx]; } // for + MPI_Allreduce(MPI_IN_PLACE, weighted_counts, npix, mpi_get_type(), + MPI_SUM, comm); + MPI_Allreduce(MPI_IN_PLACE, weighted_sin_sq, npix, mpi_get_type(), + MPI_SUM, comm); + MPI_Allreduce(MPI_IN_PLACE, weighted_cos_sq, npix, mpi_get_type(), + MPI_SUM, comm); + MPI_Allreduce(MPI_IN_PLACE, weighted_sincos, npix, mpi_get_type(), + MPI_SUM, comm); + for (ssize_t idx = 0; idx < npix; ++idx) { dfloat determinant = weighted_sin_sq[idx] * weighted_cos_sq[idx] - weighted_sincos[idx] * weighted_sincos[idx]; @@ -85,22 +102,23 @@ void compute_weights_pol_QU( // } // compute_weights_pol_QU() template -void compute_weights_pol_IQU( // - const ssize_t npix, // - const ssize_t nsamples, // - const dint *pointings, // - const bool *pointings_flag, // - const dfloat *noise_weights, // - const dfloat *pol_angles, // - dfloat *weighted_counts, // - dfloat *sin2phi, // - dfloat *cos2phi, // - dfloat *weighted_sin_sq, // - dfloat *weighted_cos_sq, // - dfloat *weighted_sincos, // - dfloat *weighted_sin, // - dfloat *weighted_cos, // - dfloat *one_over_determinant // +void compute_weights_pol_IQU( // + const ssize_t npix, // + const ssize_t nsamples, // + const dint *pointings, // + const bool *pointings_flag, // + const dfloat *noise_weights, // + const dfloat *pol_angles, // + dfloat *weighted_counts, // + dfloat *sin2phi, // + dfloat *cos2phi, // + dfloat *weighted_sin_sq, // + dfloat *weighted_cos_sq, // + dfloat *weighted_sincos, // + dfloat *weighted_sin, // + dfloat *weighted_cos, // + dfloat *one_over_determinant, // + const MPI_Comm comm // ) { for (ssize_t idx = 0; idx < nsamples; ++idx) { @@ -121,6 +139,19 @@ void compute_weights_pol_IQU( // weighted_sincos[pixel] += weight * sin2phi[idx] * cos2phi[idx]; } // for + MPI_Allreduce(MPI_IN_PLACE, weighted_counts, npix, mpi_get_type(), + MPI_SUM, comm); + MPI_Allreduce(MPI_IN_PLACE, weighted_sin, npix, mpi_get_type(), + MPI_SUM, comm); + MPI_Allreduce(MPI_IN_PLACE, weighted_cos, npix, mpi_get_type(), + MPI_SUM, comm); + MPI_Allreduce(MPI_IN_PLACE, weighted_sin_sq, npix, mpi_get_type(), + MPI_SUM, comm); + MPI_Allreduce(MPI_IN_PLACE, weighted_cos_sq, npix, mpi_get_type(), + MPI_SUM, comm); + MPI_Allreduce(MPI_IN_PLACE, weighted_sincos, npix, mpi_get_type(), + MPI_SUM, comm); + for (ssize_t idx = 0; idx < npix; ++idx) { dfloat determinant = weighted_counts[idx] * weighted_cos_sq[idx] * weighted_sin_sq[idx] + @@ -181,7 +212,8 @@ std::function weighted_counts, // buffer_t observed_pixels, // buffer_t __old2new_pixel, // - buffer_t pixel_flag // + buffer_t pixel_flag, // + const py::object mpi4py_comm // )> numpy_bind_compute_weights_pol_I = // [](const ssize_t npix, // @@ -192,7 +224,8 @@ std::function dint { py::buffer_info pointings_info = pointings.request(); py::buffer_info pointings_flags_info = pointings_flag.request(); @@ -216,6 +249,9 @@ std::function(__old2new_pixel_info.ptr); bool *pixel_flag_ptr = reinterpret_cast(pixel_flag_info.ptr); + const MPI_Comm comm = + (reinterpret_cast(mpi4py_comm.ptr()))->ob_mpi; + dint new_npix = compute_weights_pol_I( // npix, // nsamples, // @@ -225,7 +261,8 @@ std::function class buffer_t, typename dint, typename dfloat> -std::function pointings, // - const buffer_t pointings_flag, // - const buffer_t noise_weights, // - const buffer_t pol_angles, // - buffer_t weighted_counts, // - buffer_t sin2phi, // - buffer_t cos2phi, // - buffer_t weighted_sin_sq, // - buffer_t weighted_cos_sq, // - buffer_t weighted_sincos, // - buffer_t one_over_determinant // +std::function pointings, // + const buffer_t pointings_flag, // + const buffer_t noise_weights, // + const buffer_t pol_angles, // + buffer_t weighted_counts, // + buffer_t sin2phi, // + buffer_t cos2phi, // + buffer_t weighted_sin_sq, // + buffer_t weighted_cos_sq, // + buffer_t weighted_sincos, // + buffer_t one_over_determinant, // + const py::object mpi4py_comm // )> numpy_bind_compute_weights_pol_QU = // [](const ssize_t npix, // @@ -261,7 +299,8 @@ std::function(one_over_determinant_info.ptr); - compute_weights_pol_QU( // - npix, // - nsamples, // - pointings_ptr, // - pointings_flag_ptr, // - noise_weights_ptr, // - pol_angles_ptr, // - weighted_counts_ptr, // - sin2phi_ptr, // - cos2phi_ptr, // - weighted_sin_sq_ptr, // - weighted_cos_sq_ptr, // - weighted_sincos_ptr, // - one_over_determinant_ptr // + const MPI_Comm comm = + (reinterpret_cast(mpi4py_comm.ptr())) + ->ob_mpi; + + compute_weights_pol_QU( // + npix, // + nsamples, // + pointings_ptr, // + pointings_flag_ptr, // + noise_weights_ptr, // + pol_angles_ptr, // + weighted_counts_ptr, // + sin2phi_ptr, // + cos2phi_ptr, // + weighted_sin_sq_ptr, // + weighted_cos_sq_ptr, // + weighted_sincos_ptr, // + one_over_determinant_ptr, // + comm // ); return; @@ -319,22 +363,23 @@ std::function class buffer_t, typename dint, typename dfloat> -std::function pointings, // - const buffer_t pointings_flag, // - const buffer_t noise_weights, // - const buffer_t pol_angles, // - buffer_t weighted_counts, // - buffer_t sin2phi, // - buffer_t cos2phi, // - buffer_t weighted_sin_sq, // - buffer_t weighted_cos_sq, // - buffer_t weighted_sincos, // - buffer_t weighted_sin, // - buffer_t weighted_cos, // - buffer_t one_over_determinant // +std::function pointings, // + const buffer_t pointings_flag, // + const buffer_t noise_weights, // + const buffer_t pol_angles, // + buffer_t weighted_counts, // + buffer_t sin2phi, // + buffer_t cos2phi, // + buffer_t weighted_sin_sq, // + buffer_t weighted_cos_sq, // + buffer_t weighted_sincos, // + buffer_t weighted_sin, // + buffer_t weighted_cos, // + buffer_t one_over_determinant, // + const py::object mpi4py_comm // )> numpy_bind_compute_weights_pol_IQU = // [](const ssize_t npix, // @@ -351,7 +396,8 @@ std::function(one_over_determinant_info.ptr); - compute_weights_pol_IQU( // - npix, // - nsamples, // - pointings_ptr, // - pointings_flag_ptr, // - noise_weights_ptr, // - pol_angles_ptr, // - weighted_counts_ptr, // - sin2phi_ptr, // - cos2phi_ptr, // - weighted_sin_sq_ptr, // - weighted_cos_sq_ptr, // - weighted_sincos_ptr, // - weighted_sin_ptr, // - weighted_cos_ptr, // - one_over_determinant_ptr // + const MPI_Comm comm = + (reinterpret_cast(mpi4py_comm.ptr())) + ->ob_mpi; + + compute_weights_pol_IQU( // + npix, // + nsamples, // + pointings_ptr, // + pointings_flag_ptr, // + noise_weights_ptr, // + pol_angles_ptr, // + weighted_counts_ptr, // + sin2phi_ptr, // + cos2phi_ptr, // + weighted_sin_sq_ptr, // + weighted_cos_sq_ptr, // + weighted_sincos_ptr, // + weighted_sin_ptr, // + weighted_cos_ptr, // + one_over_determinant_ptr, // + comm // ); return; @@ -426,7 +477,6 @@ std::function observed_pixels, // buffer_t __old2new_pixel, // buffer_t pixel_flag // - )> numpy_bind_get_pixel_mask_pol = // [](const int solver_type, // @@ -480,7 +530,8 @@ PYBIND11_MODULE(compute_weights, m) { py::arg("weighted_counts").noconvert(), // py::arg("observed_pixels").noconvert(), // py::arg("__old2new_pixel").noconvert(), // - py::arg("pixel_flag").noconvert() // + py::arg("pixel_flag").noconvert(), // + py::arg("comm").noconvert() // ); m.def("compute_weights_pol_I", numpy_bind_compute_weights_pol_I, @@ -492,7 +543,8 @@ PYBIND11_MODULE(compute_weights, m) { py::arg("weighted_counts").noconvert(), // py::arg("observed_pixels").noconvert(), // py::arg("__old2new_pixel").noconvert(), // - py::arg("pixel_flag").noconvert() // + py::arg("pixel_flag").noconvert(), // + py::arg("comm").noconvert() // ); m.def("compute_weights_pol_I", numpy_bind_compute_weights_pol_I, @@ -504,7 +556,8 @@ PYBIND11_MODULE(compute_weights, m) { py::arg("weighted_counts").noconvert(), // py::arg("observed_pixels").noconvert(), // py::arg("__old2new_pixel").noconvert(), // - py::arg("pixel_flag").noconvert() // + py::arg("pixel_flag").noconvert(), // + py::arg("comm").noconvert() // ); m.def("compute_weights_pol_I", numpy_bind_compute_weights_pol_I, @@ -516,151 +569,160 @@ PYBIND11_MODULE(compute_weights, m) { py::arg("weighted_counts").noconvert(), // py::arg("observed_pixels").noconvert(), // py::arg("__old2new_pixel").noconvert(), // - py::arg("pixel_flag").noconvert() // + py::arg("pixel_flag").noconvert(), // + py::arg("comm").noconvert() // ); m.def("compute_weights_pol_QU", numpy_bind_compute_weights_pol_QU, - py::arg("npix"), // - py::arg("nsamples"), // - py::arg("pointings").noconvert(), // - py::arg("pointings_flag").noconvert(), // - py::arg("noise_weights").noconvert(), // - py::arg("pol_angles").noconvert(), // - py::arg("weighted_counts").noconvert(), // - py::arg("sin2phi").noconvert(), // - py::arg("cos2phi").noconvert(), // - py::arg("weighted_sin_sq").noconvert(), // - py::arg("weighted_cos_sq").noconvert(), // - py::arg("weighted_sincos").noconvert(), // - py::arg("one_over_determinant").noconvert() // + py::arg("npix"), // + py::arg("nsamples"), // + py::arg("pointings").noconvert(), // + py::arg("pointings_flag").noconvert(), // + py::arg("noise_weights").noconvert(), // + py::arg("pol_angles").noconvert(), // + py::arg("weighted_counts").noconvert(), // + py::arg("sin2phi").noconvert(), // + py::arg("cos2phi").noconvert(), // + py::arg("weighted_sin_sq").noconvert(), // + py::arg("weighted_cos_sq").noconvert(), // + py::arg("weighted_sincos").noconvert(), // + py::arg("one_over_determinant").noconvert(), // + py::arg("comm").noconvert() // ); m.def("compute_weights_pol_QU", numpy_bind_compute_weights_pol_QU, - py::arg("npix"), // - py::arg("nsamples"), // - py::arg("pointings").noconvert(), // - py::arg("pointings_flag").noconvert(), // - py::arg("noise_weights").noconvert(), // - py::arg("pol_angles").noconvert(), // - py::arg("weighted_counts").noconvert(), // - py::arg("sin2phi").noconvert(), // - py::arg("cos2phi").noconvert(), // - py::arg("weighted_sin_sq").noconvert(), // - py::arg("weighted_cos_sq").noconvert(), // - py::arg("weighted_sincos").noconvert(), // - py::arg("one_over_determinant").noconvert() // + py::arg("npix"), // + py::arg("nsamples"), // + py::arg("pointings").noconvert(), // + py::arg("pointings_flag").noconvert(), // + py::arg("noise_weights").noconvert(), // + py::arg("pol_angles").noconvert(), // + py::arg("weighted_counts").noconvert(), // + py::arg("sin2phi").noconvert(), // + py::arg("cos2phi").noconvert(), // + py::arg("weighted_sin_sq").noconvert(), // + py::arg("weighted_cos_sq").noconvert(), // + py::arg("weighted_sincos").noconvert(), // + py::arg("one_over_determinant").noconvert(), // + py::arg("comm").noconvert() // ); m.def("compute_weights_pol_QU", numpy_bind_compute_weights_pol_QU, - py::arg("npix"), // - py::arg("nsamples"), // - py::arg("pointings").noconvert(), // - py::arg("pointings_flag").noconvert(), // - py::arg("noise_weights").noconvert(), // - py::arg("pol_angles").noconvert(), // - py::arg("weighted_counts").noconvert(), // - py::arg("sin2phi").noconvert(), // - py::arg("cos2phi").noconvert(), // - py::arg("weighted_sin_sq").noconvert(), // - py::arg("weighted_cos_sq").noconvert(), // - py::arg("weighted_sincos").noconvert(), // - py::arg("one_over_determinant").noconvert() // + py::arg("npix"), // + py::arg("nsamples"), // + py::arg("pointings").noconvert(), // + py::arg("pointings_flag").noconvert(), // + py::arg("noise_weights").noconvert(), // + py::arg("pol_angles").noconvert(), // + py::arg("weighted_counts").noconvert(), // + py::arg("sin2phi").noconvert(), // + py::arg("cos2phi").noconvert(), // + py::arg("weighted_sin_sq").noconvert(), // + py::arg("weighted_cos_sq").noconvert(), // + py::arg("weighted_sincos").noconvert(), // + py::arg("one_over_determinant").noconvert(), // + py::arg("comm").noconvert() // ); m.def("compute_weights_pol_QU", numpy_bind_compute_weights_pol_QU, - py::arg("npix"), // - py::arg("nsamples"), // - py::arg("pointings").noconvert(), // - py::arg("pointings_flag").noconvert(), // - py::arg("noise_weights").noconvert(), // - py::arg("pol_angles").noconvert(), // - py::arg("weighted_counts").noconvert(), // - py::arg("sin2phi").noconvert(), // - py::arg("cos2phi").noconvert(), // - py::arg("weighted_sin_sq").noconvert(), // - py::arg("weighted_cos_sq").noconvert(), // - py::arg("weighted_sincos").noconvert(), // - py::arg("one_over_determinant").noconvert() // + py::arg("npix"), // + py::arg("nsamples"), // + py::arg("pointings").noconvert(), // + py::arg("pointings_flag").noconvert(), // + py::arg("noise_weights").noconvert(), // + py::arg("pol_angles").noconvert(), // + py::arg("weighted_counts").noconvert(), // + py::arg("sin2phi").noconvert(), // + py::arg("cos2phi").noconvert(), // + py::arg("weighted_sin_sq").noconvert(), // + py::arg("weighted_cos_sq").noconvert(), // + py::arg("weighted_sincos").noconvert(), // + py::arg("one_over_determinant").noconvert(), // + py::arg("comm").noconvert() // ); m.def("compute_weights_pol_IQU", numpy_bind_compute_weights_pol_IQU, - py::arg("npix"), // - py::arg("nsamples"), // - py::arg("pointings").noconvert(), // - py::arg("pointings_flag").noconvert(), // - py::arg("noise_weights").noconvert(), // - py::arg("pol_angles").noconvert(), // - py::arg("weighted_counts").noconvert(), // - py::arg("sin2phi").noconvert(), // - py::arg("cos2phi").noconvert(), // - py::arg("weighted_sin_sq").noconvert(), // - py::arg("weighted_cos_sq").noconvert(), // - py::arg("weighted_sincos").noconvert(), // - py::arg("weighted_sin").noconvert(), // - py::arg("weighted_cos").noconvert(), // - py::arg("one_over_determinant").noconvert() // + py::arg("npix"), // + py::arg("nsamples"), // + py::arg("pointings").noconvert(), // + py::arg("pointings_flag").noconvert(), // + py::arg("noise_weights").noconvert(), // + py::arg("pol_angles").noconvert(), // + py::arg("weighted_counts").noconvert(), // + py::arg("sin2phi").noconvert(), // + py::arg("cos2phi").noconvert(), // + py::arg("weighted_sin_sq").noconvert(), // + py::arg("weighted_cos_sq").noconvert(), // + py::arg("weighted_sincos").noconvert(), // + py::arg("weighted_sin").noconvert(), // + py::arg("weighted_cos").noconvert(), // + py::arg("one_over_determinant").noconvert(), // + py::arg("comm").noconvert() // ); m.def("compute_weights_pol_IQU", numpy_bind_compute_weights_pol_IQU, - py::arg("npix"), // - py::arg("nsamples"), // - py::arg("pointings").noconvert(), // - py::arg("pointings_flag").noconvert(), // - py::arg("noise_weights").noconvert(), // - py::arg("pol_angles").noconvert(), // - py::arg("weighted_counts").noconvert(), // - py::arg("sin2phi").noconvert(), // - py::arg("cos2phi").noconvert(), // - py::arg("weighted_sin_sq").noconvert(), // - py::arg("weighted_cos_sq").noconvert(), // - py::arg("weighted_sincos").noconvert(), // - py::arg("weighted_sin").noconvert(), // - py::arg("weighted_cos").noconvert(), // - py::arg("one_over_determinant").noconvert() // + py::arg("npix"), // + py::arg("nsamples"), // + py::arg("pointings").noconvert(), // + py::arg("pointings_flag").noconvert(), // + py::arg("noise_weights").noconvert(), // + py::arg("pol_angles").noconvert(), // + py::arg("weighted_counts").noconvert(), // + py::arg("sin2phi").noconvert(), // + py::arg("cos2phi").noconvert(), // + py::arg("weighted_sin_sq").noconvert(), // + py::arg("weighted_cos_sq").noconvert(), // + py::arg("weighted_sincos").noconvert(), // + py::arg("weighted_sin").noconvert(), // + py::arg("weighted_cos").noconvert(), // + py::arg("one_over_determinant").noconvert(), // + py::arg("comm").noconvert() // ); m.def("compute_weights_pol_IQU", numpy_bind_compute_weights_pol_IQU, - py::arg("npix"), // - py::arg("nsamples"), // - py::arg("pointings").noconvert(), // - py::arg("pointings_flag").noconvert(), // - py::arg("noise_weights").noconvert(), // - py::arg("pol_angles").noconvert(), // - py::arg("weighted_counts").noconvert(), // - py::arg("sin2phi").noconvert(), // - py::arg("cos2phi").noconvert(), // - py::arg("weighted_sin_sq").noconvert(), // - py::arg("weighted_cos_sq").noconvert(), // - py::arg("weighted_sincos").noconvert(), // - py::arg("weighted_sin").noconvert(), // - py::arg("weighted_cos").noconvert(), // - py::arg("one_over_determinant").noconvert() // + py::arg("npix"), // + py::arg("nsamples"), // + py::arg("pointings").noconvert(), // + py::arg("pointings_flag").noconvert(), // + py::arg("noise_weights").noconvert(), // + py::arg("pol_angles").noconvert(), // + py::arg("weighted_counts").noconvert(), // + py::arg("sin2phi").noconvert(), // + py::arg("cos2phi").noconvert(), // + py::arg("weighted_sin_sq").noconvert(), // + py::arg("weighted_cos_sq").noconvert(), // + py::arg("weighted_sincos").noconvert(), // + py::arg("weighted_sin").noconvert(), // + py::arg("weighted_cos").noconvert(), // + py::arg("one_over_determinant").noconvert(), // + py::arg("comm").noconvert() // ); m.def("compute_weights_pol_IQU", numpy_bind_compute_weights_pol_IQU, - py::arg("npix"), // - py::arg("nsamples"), // - py::arg("pointings").noconvert(), // - py::arg("pointings_flag").noconvert(), // - py::arg("noise_weights").noconvert(), // - py::arg("pol_angles").noconvert(), // - py::arg("weighted_counts").noconvert(), // - py::arg("sin2phi").noconvert(), // - py::arg("cos2phi").noconvert(), // - py::arg("weighted_sin_sq").noconvert(), // - py::arg("weighted_cos_sq").noconvert(), // - py::arg("weighted_sincos").noconvert(), // - py::arg("weighted_sin").noconvert(), // - py::arg("weighted_cos").noconvert(), // - py::arg("one_over_determinant").noconvert() // + py::arg("npix"), // + py::arg("nsamples"), // + py::arg("pointings").noconvert(), // + py::arg("pointings_flag").noconvert(), // + py::arg("noise_weights").noconvert(), // + py::arg("pol_angles").noconvert(), // + py::arg("weighted_counts").noconvert(), // + py::arg("sin2phi").noconvert(), // + py::arg("cos2phi").noconvert(), // + py::arg("weighted_sin_sq").noconvert(), // + py::arg("weighted_cos_sq").noconvert(), // + py::arg("weighted_sincos").noconvert(), // + py::arg("weighted_sin").noconvert(), // + py::arg("weighted_cos").noconvert(), // + py::arg("one_over_determinant").noconvert(), // + py::arg("comm").noconvert() // ); m.def("get_pixel_mask_pol", diff --git a/brahmap/_extensions/mpi_utils.hpp b/brahmap/_extensions/mpi_utils.hpp new file mode 100644 index 0000000..d2806fb --- /dev/null +++ b/brahmap/_extensions/mpi_utils.hpp @@ -0,0 +1,75 @@ +#ifndef _MPI_UTILS +#define _MPI_UTILS + +#include +#include + +// The following function is taken from +// +[[nodiscard]] constexpr MPI_Datatype mpi_get_type() noexcept { + + MPI_Datatype mpi_type = MPI_DATATYPE_NULL; + + if constexpr (std::is_same::value) { + mpi_type = MPI_CHAR; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_SIGNED_CHAR; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_UNSIGNED_CHAR; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_WCHAR; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_SHORT; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_UNSIGNED_SHORT; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_INT; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_UNSIGNED; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_LONG; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_UNSIGNED_LONG; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_LONG_LONG; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_UNSIGNED_LONG_LONG; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_FLOAT; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_DOUBLE; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_LONG_DOUBLE; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_INT8_T; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_INT16_T; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_INT32_T; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_INT64_T; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_UINT8_T; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_UINT16_T; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_UINT32_T; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_UINT64_T; + } else if constexpr (std::is_same::value) { + mpi_type = MPI_C_BOOL; + } else if constexpr (std::is_same>::value) { + mpi_type = MPI_C_COMPLEX; + } else if constexpr (std::is_same>::value) { + mpi_type = MPI_C_DOUBLE_COMPLEX; + } else if constexpr (std::is_same>::value) { + mpi_type = MPI_C_LONG_DOUBLE_COMPLEX; + } // if + + assert(mpi_type != MPI_DATATYPE_NULL); + return mpi_type; + +} // mpi_get_type() + +#endif \ No newline at end of file diff --git a/brahmap/utilities/__init__.py b/brahmap/utilities/__init__.py index 3aa7ddb..acaaced 100644 --- a/brahmap/utilities/__init__.py +++ b/brahmap/utilities/__init__.py @@ -22,7 +22,7 @@ from .process_time_samples import ProcessTimeSamples, SolverType -from .mpi import Initialize +from .mpi import Initialize, MPI_RAISE_EXCEPTION __all__ = [ "is_sorted", @@ -41,4 +41,5 @@ "SolverType", "TypeChangeWarning", "Initialize", + "MPI_RAISE_EXCEPTION", ] diff --git a/brahmap/utilities/mpi.py b/brahmap/utilities/mpi.py index 4a87955..3e08237 100644 --- a/brahmap/utilities/mpi.py +++ b/brahmap/utilities/mpi.py @@ -1,15 +1,7 @@ import os import brahmap -import mpi4py - -mpi4py.rc.initialize = False -mpi4py.rc.finalize = False - -from mpi4py import MPI # noqa: E402 - -if MPI.Is_initialized() is False: - MPI.Init_thread(required=MPI.THREAD_FUNNELED) +from mpi4py import MPI def Initialize(communicator=None, raise_exception_per_process: bool = True): @@ -40,7 +32,7 @@ def MPI_RAISE_EXCEPTION( exception, message, ): - """Will raise `exception` with `message` if the `condition` is false. + """Will raise `exception` with `message` if the `condition` is `True`. Args: condition (_type_): The condition to be evaluated @@ -53,12 +45,12 @@ def MPI_RAISE_EXCEPTION( """ if brahmap.bMPI.raise_exception_per_process: - if condition is False: + if condition is True: error_str = f"Exception raised by MPI rank {brahmap.bMPI.rank}\n" raise exception(error_str + message) else: exception_count = brahmap.bMPI.comm.reduce(condition, MPI.SUM, 0) if brahmap.bMPI.rank == 0: - error_str = f"Exception raised by {brahmap.bMPI.comm.size - exception_count} MPI process(es)\n" + error_str = f"Exception raised by {exception_count} MPI process(es)\n" raise exception(error_str + message) diff --git a/brahmap/utilities/process_time_samples.py b/brahmap/utilities/process_time_samples.py index 553681f..9eaeae6 100644 --- a/brahmap/utilities/process_time_samples.py +++ b/brahmap/utilities/process_time_samples.py @@ -2,6 +2,8 @@ import numpy as np import warnings +import brahmap + from brahmap.utilities.tools import TypeChangeWarning from brahmap.utilities import bash_colors @@ -9,6 +11,9 @@ from brahmap._extensions import repixelize +from mpi4py import MPI + + class SolverType(IntEnum): I = 1 # noqa: E741 QU = 2 @@ -28,6 +33,9 @@ def __init__( dtype_float=None, update_pointings_inplace: bool = True, ): + if brahmap.bMPI is None: + brahmap.Initialize() + self.npix = npix self.nsamples = len(pointings) @@ -42,18 +50,20 @@ def __init__( if self.pointings_flag is None: self.pointings_flag = np.ones(self.nsamples, dtype=bool) - if len(self.pointings_flag) != self.nsamples: - raise AssertionError( - f"Size of `pointings_flag` must be equal to the size of `pointings` array:\nlen(pointings_flag) = {len(pointings_flag)}\nlen(pointings) = {self.nsamples}" - ) + brahmap.MPI_RAISE_EXCEPTION( + condition=(len(self.pointings_flag) != self.nsamples), + exception=AssertionError, + message=f"Size of `pointings_flag` must be equal to the size of `pointings` array:\nlen(pointings_flag) = {len(pointings_flag)}\nlen(pointings) = {self.nsamples}", + ) self.threshold = threshold self.solver_type = solver_type - if self.solver_type not in [1, 2, 3]: - raise ValueError( - "Invalid `solver_type`!!!\n`solver_type` must be either SolverType.I, SolverType.QU or SolverType.IQU (equivalently 1, 2 or 3)." - ) + brahmap.MPI_RAISE_EXCEPTION( + condition=(self.solver_type not in [1, 2, 3]), + exception=ValueError, + message="Invalid `solver_type`!!!\n`solver_type` must be either SolverType.I, SolverType.QU or SolverType.IQU (equivalently 1, 2 or 3).", + ) # setting the dtype for the `float` arrays: if one or both of `noise_weights` and `pol_angles` are supplied, the `dtype_float` will be inferred from them. Otherwise, the it will be set to `np.float64` if dtype_float is not None: @@ -71,29 +81,33 @@ def __init__( if noise_weights is None: noise_weights = np.ones(self.nsamples, dtype=self.dtype_float) - if len(noise_weights) != self.nsamples: - raise AssertionError( - f"Size of `noise_weights` must be equal to the size of `pointings` array:\nlen(noise_weigths) = {len(noise_weights)}\nlen(pointings) = {self.nsamples}" - ) + brahmap.MPI_RAISE_EXCEPTION( + condition=(len(noise_weights) != self.nsamples), + exception=AssertionError, + message=f"Size of `noise_weights` must be equal to the size of `pointings` array:\nlen(noise_weigths) = {len(noise_weights)}\nlen(pointings) = {self.nsamples}", + ) if noise_weights.dtype != self.dtype_float: - warnings.warn( - f"dtype of `noise_weights` will be changed to {self.dtype_float}", - TypeChangeWarning, - ) + if brahmap.bMPI.rank == 0: + warnings.warn( + f"dtype of `noise_weights` will be changed to {self.dtype_float}", + TypeChangeWarning, + ) noise_weights = noise_weights.astype(dtype=self.dtype_float, copy=False) if self.solver_type != 1: - if len(pol_angles) != self.nsamples: - raise AssertionError( - f"Size of `pol_angles` must be equal to the size of `pointings` array:\nlen(pol_angles) = {len(pol_angles)}\nlen(pointings) = {self.nsamples}" - ) + brahmap.MPI_RAISE_EXCEPTION( + condition=(len(pol_angles) != self.nsamples), + exception=AssertionError, + message=f"Size of `pol_angles` must be equal to the size of `pointings` array:\nlen(pol_angles) = {len(pol_angles)}\nlen(pointings) = {self.nsamples}", + ) if pol_angles.dtype != self.dtype_float: - warnings.warn( - f"dtype of `pol_angles` will be changed to {self.dtype_float}", - TypeChangeWarning, - ) + if brahmap.bMPI.rank == 0: + warnings.warn( + f"dtype of `pol_angles` will be changed to {self.dtype_float}", + TypeChangeWarning, + ) pol_angles = pol_angles.astype(dtype=self.dtype_float, copy=False) self._compute_weights( @@ -104,22 +118,27 @@ def __init__( self._repixelization() self._flag_bad_pixel_samples() - bc = bash_colors() - print(bc.header(f"{bc.bold(' ProcessTimeSamples Summary '):-^60}")) - print( - bc.blue(bc.bold(f"Read {self.nsamples} time samples for npix={self.npix}")) - ) - print( - bc.blue(bc.bold(f"Found {self.npix - self.new_npix} pathological pixels")) - ) - print( - bc.blue( - bc.bold( - f"Map-maker will take into account only {self.new_npix} pixels." + if brahmap.bMPI.rank == 0: + bc = bash_colors() + print(bc.header(f"{bc.bold(' ProcessTimeSamples Summary '):-^60}")) + print( + bc.blue( + bc.bold(f"Read {self.nsamples} time samples for npix={self.npix}") ) ) - ) - print(bc.header("---" * 20)) + print( + bc.blue( + bc.bold(f"Found {self.npix - self.new_npix} pathological pixels") + ) + ) + print( + bc.blue( + bc.bold( + f"Map-maker will take into account only {self.new_npix} pixels." + ) + ) + ) + print(bc.header("---" * 20)) def get_hit_counts(self): """Returns hit counts of the pixel indices""" @@ -127,6 +146,8 @@ def get_hit_counts(self): for idx in range(self.nsamples): hit_counts_newidx[self.pointings[idx]] += self.pointings_flag[idx] + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, hit_counts_newidx, MPI.SUM) + hit_counts = np.ma.masked_array( data=np.zeros(self.npix), mask=np.logical_not(self.pixel_flag), @@ -163,6 +184,7 @@ def _compute_weights(self, pol_angles: np.ndarray, noise_weights: np.ndarray): observed_pixels=self.observed_pixels, __old2new_pixel=self.__old2new_pixel, pixel_flag=self.pixel_flag, + comm=brahmap.bMPI.comm, ) else: @@ -190,6 +212,7 @@ def _compute_weights(self, pol_angles: np.ndarray, noise_weights: np.ndarray): weighted_cos_sq=self.weighted_cos_sq, weighted_sincos=self.weighted_sincos, one_over_determinant=self.one_over_determinant, + comm=brahmap.bMPI.comm, ) elif self.solver_type == SolverType.IQU: @@ -212,6 +235,7 @@ def _compute_weights(self, pol_angles: np.ndarray, noise_weights: np.ndarray): weighted_sin=self.weighted_sin, weighted_cos=self.weighted_cos, one_over_determinant=self.one_over_determinant, + comm=brahmap.bMPI.comm, ) self.new_npix = compute_weights.get_pixel_mask_pol( diff --git a/pyproject.toml b/pyproject.toml index e5e120e..445dbca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,10 @@ [build-system] -requires = ["setuptools", "setuptools_scm"] +requires = [ + "setuptools", + "setuptools_scm", + "mpi4py", # `mpi4py` is needed to use `mpi4py.get_include()` in `setup.py` +] + build-backend = "setuptools.build_meta" [project] @@ -17,6 +22,7 @@ dependencies = [ "numpy", "scipy", "healpy", + "mpi4py", "ruff", "pre-commit", "pytest", diff --git a/setup.py b/setup.py index df6d335..a13e011 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ import os from setuptools import Extension, setup +import mpi4py # g++ -O3 -march=native -Wall -shared -std=c++14 -fPIC $(python3 -m pybind11 --includes) example9.cpp -o example9$(python3-config --extension-suffix) @@ -10,6 +11,7 @@ include_dirs=[ os.path.join("brahmap", "_extensions"), os.path.join("extern", "pybind11", "include"), + os.path.join(mpi4py.get_include()), ], define_macros=None, extra_compile_args=[ @@ -20,7 +22,6 @@ "-std=c++20", "-fPIC", "-fvisibility=hidden", - "-lm", ], ) diff --git a/tests/helper_ComputeWeights.py b/tests/helper_ComputeWeights.py index 391e9f4..1421a82 100644 --- a/tests/helper_ComputeWeights.py +++ b/tests/helper_ComputeWeights.py @@ -1,4 +1,5 @@ import numpy as np +from mpi4py import MPI def computeweights_pol_I( @@ -8,6 +9,7 @@ def computeweights_pol_I( pointings_flag: np.ndarray, noise_weights: np.ndarray, dtype_float, + comm, ): weighted_counts = np.zeros(npix, dtype=dtype_float) @@ -17,6 +19,8 @@ def computeweights_pol_I( if pointings_flag[idx]: weighted_counts[pixel] += noise_weights[idx] + comm.Allreduce(MPI.IN_PLACE, weighted_counts, MPI.SUM) + observed_pixels = np.where(weighted_counts > 0)[0] new_npix = len(observed_pixels) @@ -42,6 +46,7 @@ def computeweights_pol_QU( noise_weights: np.ndarray, pol_angles: np.ndarray, dtype_float, + comm, ): weighted_counts = np.zeros(npix, dtype=dtype_float) weighted_sin_sq = np.zeros(npix, dtype=dtype_float) @@ -61,6 +66,11 @@ def computeweights_pol_QU( weighted_cos_sq[pixel] += noise_weights[idx] * cos2phi[idx] * cos2phi[idx] weighted_sincos[pixel] += noise_weights[idx] * sin2phi[idx] * cos2phi[idx] + comm.Allreduce(MPI.IN_PLACE, weighted_counts, MPI.SUM) + comm.Allreduce(MPI.IN_PLACE, weighted_sin_sq, MPI.SUM) + comm.Allreduce(MPI.IN_PLACE, weighted_cos_sq, MPI.SUM) + comm.Allreduce(MPI.IN_PLACE, weighted_sincos, MPI.SUM) + one_over_determinant = (weighted_cos_sq * weighted_sin_sq) - ( weighted_sincos * weighted_sincos ) @@ -84,6 +94,7 @@ def computeweights_pol_IQU( noise_weights: np.ndarray, pol_angles: np.ndarray, dtype_float, + comm, ): weighted_counts = np.zeros(npix, dtype=dtype_float) weighted_sin_sq = np.zeros(npix, dtype=dtype_float) @@ -107,6 +118,13 @@ def computeweights_pol_IQU( weighted_sin[pixel] += noise_weights[idx] * sin2phi[idx] weighted_cos[pixel] += noise_weights[idx] * cos2phi[idx] + comm.Allreduce(MPI.IN_PLACE, weighted_counts, MPI.SUM) + comm.Allreduce(MPI.IN_PLACE, weighted_sin, MPI.SUM) + comm.Allreduce(MPI.IN_PLACE, weighted_cos, MPI.SUM) + comm.Allreduce(MPI.IN_PLACE, weighted_sin_sq, MPI.SUM) + comm.Allreduce(MPI.IN_PLACE, weighted_cos_sq, MPI.SUM) + comm.Allreduce(MPI.IN_PLACE, weighted_sincos, MPI.SUM) + one_over_determinant = ( weighted_counts * (weighted_cos_sq * weighted_sin_sq - weighted_sincos * weighted_sincos) diff --git a/tests/helper_ProcessTimeSamples.py b/tests/helper_ProcessTimeSamples.py index 26fa4d1..8d950e6 100644 --- a/tests/helper_ProcessTimeSamples.py +++ b/tests/helper_ProcessTimeSamples.py @@ -5,8 +5,11 @@ import helper_ComputeWeights as cw import helper_Repixelization as rp +import brahmap from brahmap.utilities import TypeChangeWarning +from mpi4py import MPI + class SolverType(IntEnum): I = 1 # noqa: E741 @@ -27,6 +30,9 @@ def __init__( dtype_float=None, update_pointings_inplace: bool = True, ): + if brahmap.bMPI is None: + brahmap.Initialize() + self.npix = npix self.nsamples = len(pointings) @@ -87,6 +93,8 @@ def get_hit_counts(self, mask_fill_value=np.nan): for idx in range(self.nsamples): hit_counts_newidx[self.pointings[idx]] += self.pointings_flag[idx] + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, hit_counts_newidx, MPI.SUM) + hit_counts = np.ma.masked_array( data=np.zeros(self.npix, dtype=int), mask=np.logical_not(self.pixel_flag), @@ -120,6 +128,7 @@ def _compute_weights(self, pol_angles, noise_weights): pointings_flag=self.pointings_flag, noise_weights=noise_weights, dtype_float=self.dtype_float, + comm=brahmap.bMPI.comm, ) else: @@ -140,6 +149,7 @@ def _compute_weights(self, pol_angles, noise_weights): noise_weights=noise_weights, pol_angles=pol_angles, dtype_float=self.dtype_float, + comm=brahmap.bMPI.comm, ) elif self.solver_type == SolverType.IQU: @@ -161,6 +171,7 @@ def _compute_weights(self, pol_angles, noise_weights): noise_weights=noise_weights, pol_angles=pol_angles, dtype_float=self.dtype_float, + comm=brahmap.bMPI.comm, ) ( diff --git a/tests/test_BlkDiagPrecondLO.py b/tests/test_BlkDiagPrecondLO.py index ded4609..b7825be 100644 --- a/tests/test_BlkDiagPrecondLO.py +++ b/tests/test_BlkDiagPrecondLO.py @@ -5,11 +5,16 @@ import helper_BlkDiagPrecondLO as bdplo import helper_ProcessTimeSamples as hpts +brahmap.Initialize() + class InitCommonParams: - np.random.seed(6543) + np.random.seed(65434 + brahmap.bMPI.rank) npix = 128 - nsamples = npix * 6 + nsamples_global = npix * 6 + + div, rem = divmod(nsamples_global, brahmap.bMPI.size) + nsamples = div + (brahmap.bMPI.rank < rem) pointings_flag = np.ones(nsamples, dtype=bool) bad_samples = np.random.randint(low=0, high=nsamples, size=npix) @@ -205,8 +210,8 @@ def test_I(self, initint, initfloat, rtol): @pytest.mark.parametrize( "initint, initfloat, rtol", [ - (InitInt32Params(), InitFloat32Params(), 1.5e-4), - (InitInt64Params(), InitFloat32Params(), 1.5e-4), + (InitInt32Params(), InitFloat32Params(), 1.5e-3), + (InitInt64Params(), InitFloat32Params(), 1.5e-3), (InitInt32Params(), InitFloat64Params(), 1.5e-5), (InitInt64Params(), InitFloat64Params(), 1.5e-5), ], @@ -252,8 +257,8 @@ def test_QU(self, initint, initfloat, rtol): @pytest.mark.parametrize( "initint, initfloat, rtol", [ - (InitInt32Params(), InitFloat32Params(), 1.5e-3), - (InitInt64Params(), InitFloat32Params(), 1.5e-3), + (InitInt32Params(), InitFloat32Params(), 1.0e-3), + (InitInt64Params(), InitFloat32Params(), 1.0e-3), (InitInt32Params(), InitFloat64Params(), 1.5e-5), (InitInt64Params(), InitFloat64Params(), 1.5e-5), ], diff --git a/tests/test_BlkDiagPrecondLO_tools_cpp.py b/tests/test_BlkDiagPrecondLO_tools_cpp.py index 1e9515c..9f4f0e6 100644 --- a/tests/test_BlkDiagPrecondLO_tools_cpp.py +++ b/tests/test_BlkDiagPrecondLO_tools_cpp.py @@ -1,15 +1,22 @@ import pytest import numpy as np + +import brahmap from brahmap._extensions import BlkDiagPrecondLO_tools import helper_ProcessTimeSamples as hpts import helper_BlkDiagPrecondLO_tools as bdplo_tools +brahmap.Initialize() + class InitCommonParams: - np.random.seed(987) + np.random.seed(987 + brahmap.bMPI.rank) npix = 128 - nsamples = npix * 6 + nsamples_global = npix * 6 + + div, rem = divmod(nsamples_global, brahmap.bMPI.size) + nsamples = div + (brahmap.bMPI.rank < rem) pointings_flag = np.ones(nsamples, dtype=bool) bad_samples = np.random.randint(low=0, high=nsamples, size=npix) diff --git a/tests/test_PointingLO.py b/tests/test_PointingLO.py index f11bbc4..85f31d8 100644 --- a/tests/test_PointingLO.py +++ b/tests/test_PointingLO.py @@ -7,9 +7,12 @@ class InitCommonParams: - np.random.seed(54321) + np.random.seed(54321 + brahmap.bMPI.rank) npix = 128 - nsamples = npix * 6 + nsamples_global = npix * 6 + + div, rem = divmod(nsamples_global, brahmap.bMPI.size) + nsamples = div + (brahmap.bMPI.rank < rem) pointings_flag = np.ones(nsamples, dtype=bool) bad_samples = np.random.randint(low=0, high=nsamples, size=npix) diff --git a/tests/test_ProcessTimeSamples.py b/tests/test_ProcessTimeSamples.py index 42bddd0..fff205b 100644 --- a/tests/test_ProcessTimeSamples.py +++ b/tests/test_ProcessTimeSamples.py @@ -1,14 +1,23 @@ import pytest import numpy as np +import brahmap import brahmap.utilities as bmutils + import helper_ProcessTimeSamples as hpts +from mpi4py import MPI + +brahmap.Initialize() + class InitCommonParams: - np.random.seed(12345) + np.random.seed(12345 + brahmap.bMPI.rank) npix = 128 - nsamples = npix * 6 + nsamples_global = npix * 6 + + div, rem = divmod(nsamples_global, brahmap.bMPI.size) + nsamples = div + (brahmap.bMPI.rank < rem) pointings_flag = np.ones(nsamples, dtype=bool) bad_samples = np.random.randint(low=0, high=nsamples, size=npix) @@ -64,8 +73,8 @@ def __init__(self) -> None: @pytest.mark.parametrize( "initint, initfloat, rtol", [ - (InitInt32Params(), InitFloat32Params(), 1.5e-4), - (InitInt64Params(), InitFloat32Params(), 1.5e-4), + (InitInt32Params(), InitFloat32Params(), 1.5e-3), + (InitInt64Params(), InitFloat32Params(), 1.5e-3), (InitInt32Params(), InitFloat64Params(), 1.5e-5), (InitInt64Params(), InitFloat64Params(), 1.5e-5), ], @@ -208,8 +217,8 @@ def test_ProcessTimeSamples_IQU_Cpp(self, initint, initfloat, rtol): @pytest.mark.parametrize( "initint, initfloat, rtol", [ - (InitInt32Params(), InitFloat32Params(), 1.5e-4), - (InitInt64Params(), InitFloat32Params(), 1.5e-4), + (InitInt32Params(), InitFloat32Params(), 1.5e-3), + (InitInt64Params(), InitFloat32Params(), 1.5e-3), (InitInt32Params(), InitFloat64Params(), 1.5e-5), (InitInt64Params(), InitFloat64Params(), 1.5e-5), ], @@ -235,6 +244,8 @@ def test_ProcessTimeSamples_I(self, initint, initfloat, rtol): pixel = PTS.pointings[idx] weighted_counts[pixel] += initfloat.noise_weights[idx] + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, weighted_counts, MPI.SUM) + np.testing.assert_allclose(PTS.weighted_counts, weighted_counts, rtol=rtol) def test_ProcessTimeSamples_QU(self, initint, initfloat, rtol): @@ -273,6 +284,11 @@ def test_ProcessTimeSamples_QU(self, initint, initfloat, rtol): initfloat.noise_weights[idx] * sin2phi[idx] * cos2phi[idx] ) + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, weighted_counts, MPI.SUM) + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, weighted_sin_sq, MPI.SUM) + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, weighted_cos_sq, MPI.SUM) + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, weighted_sincos, MPI.SUM) + one_over_determinant = 1.0 / ( (weighted_cos_sq * weighted_sin_sq) - (weighted_sincos * weighted_sincos) ) @@ -327,6 +343,13 @@ def test_ProcessTimeSamples_IQU(self, initint, initfloat, rtol): weighted_sin[pixel] += initfloat.noise_weights[idx] * sin2phi[idx] weighted_cos[pixel] += initfloat.noise_weights[idx] * cos2phi[idx] + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, weighted_counts, MPI.SUM) + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, weighted_sin, MPI.SUM) + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, weighted_cos, MPI.SUM) + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, weighted_sin_sq, MPI.SUM) + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, weighted_cos_sq, MPI.SUM) + brahmap.bMPI.comm.Allreduce(MPI.IN_PLACE, weighted_sincos, MPI.SUM) + one_over_determinant = 1.0 / ( weighted_counts * (weighted_cos_sq * weighted_sin_sq - weighted_sincos * weighted_sincos) diff --git a/tests/test_compute_weights_cpp.py b/tests/test_compute_weights_cpp.py index 4cdbe31..c05d898 100644 --- a/tests/test_compute_weights_cpp.py +++ b/tests/test_compute_weights_cpp.py @@ -1,14 +1,21 @@ import pytest import numpy as np + +import brahmap from brahmap._extensions import compute_weights import helper_ComputeWeights as cw +brahmap.Initialize() + class InitCommonParams: - np.random.seed(1234) + np.random.seed(1234 + brahmap.bMPI.rank) npix = 128 - nsamples = npix * 6 + nsamples_global = npix * 6 + + div, rem = divmod(nsamples_global, brahmap.bMPI.size) + nsamples = div + (brahmap.bMPI.rank < rem) pointings_flag = np.ones(nsamples, dtype=bool) bad_samples = np.random.randint(low=0, high=nsamples, size=npix) @@ -87,6 +94,7 @@ def test_compute_weights_pol_I(self, initint, initfloat, rtol): cpp_observed_pixels, cpp_old2new_pixel, cpp_pixel_flag, + brahmap.bMPI.comm, ) ( @@ -102,6 +110,7 @@ def test_compute_weights_pol_I(self, initint, initfloat, rtol): self.pointings_flag, initfloat.noise_weights, dtype_float=initfloat.dtype, + comm=brahmap.bMPI.comm, ) cpp_observed_pixels.resize(cpp_new_npix, refcheck=False) @@ -136,6 +145,7 @@ def test_compute_weights_pol_QU(self, initint, initfloat, rtol): cpp_weighted_cos_sq, cpp_weighted_sincos, cpp_one_over_determinant, + brahmap.bMPI.comm, ) ( @@ -154,6 +164,7 @@ def test_compute_weights_pol_QU(self, initint, initfloat, rtol): initfloat.noise_weights, initfloat.pol_angles, dtype_float=initfloat.dtype, + comm=brahmap.bMPI.comm, ) np.testing.assert_allclose(cpp_weighted_counts, py_weighted_counts, rtol=rtol) @@ -191,6 +202,7 @@ def test_compute_weights_pol_IQU(self, initint, initfloat, rtol): cpp_weighted_sin, cpp_weighted_cos, cpp_one_over_determinant, + brahmap.bMPI.comm, ) ( @@ -211,6 +223,7 @@ def test_compute_weights_pol_IQU(self, initint, initfloat, rtol): initfloat.noise_weights, initfloat.pol_angles, dtype_float=initfloat.dtype, + comm=brahmap.bMPI.comm, ) np.testing.assert_allclose(cpp_weighted_counts, py_weighted_counts, rtol=rtol) @@ -239,6 +252,7 @@ def test_get_pix_mask_pol_QU(self, initint, initfloat, rtol): initfloat.noise_weights, initfloat.pol_angles, dtype_float=initfloat.dtype, + comm=brahmap.bMPI.comm, ) cpp_observed_pixels = np.zeros(self.npix, initint.dtype) @@ -296,6 +310,7 @@ def test_get_pix_mask_pol_IQU(self, initint, initfloat, rtol): initfloat.noise_weights, initfloat.pol_angles, dtype_float=initfloat.dtype, + comm=brahmap.bMPI.comm, ) cpp_observed_pixels = np.zeros(self.npix, initint.dtype) diff --git a/tests/test_repixelization_cpp.py b/tests/test_repixelization_cpp.py index bd9047c..5c4c114 100644 --- a/tests/test_repixelization_cpp.py +++ b/tests/test_repixelization_cpp.py @@ -1,5 +1,7 @@ import pytest import numpy as np + +import brahmap from brahmap._extensions import repixelize import helper_ComputeWeights as cw @@ -7,9 +9,12 @@ class InitCommonParams: - np.random.seed(1234) + np.random.seed(1234 + brahmap.bMPI.rank) npix = 128 - nsamples = npix * 6 + nsamples_global = npix * 6 + + div, rem = divmod(nsamples_global, brahmap.bMPI.size) + nsamples = div + (brahmap.bMPI.rank < rem) pointings_flag = np.ones(nsamples, dtype=bool) bad_samples = np.random.randint(low=0, high=nsamples, size=npix) @@ -80,6 +85,7 @@ def test_repixelize_pol_I(self, initint, initfloat, rtol): self.pointings_flag, initfloat.noise_weights, initfloat.dtype, + comm=brahmap.bMPI.comm, ) cpp_weighted_counts = py_weighted_counts.copy() @@ -111,6 +117,7 @@ def test_repixelize_pol_QU(self, initint, initfloat, rtol): initfloat.noise_weights, initfloat.pol_angles, dtype_float=initfloat.dtype, + comm=brahmap.bMPI.comm, ) new_npix, observed_pixels, __, __ = cw.get_pix_mask_pol( @@ -187,6 +194,7 @@ def test_repixelize_pol_IQU(self, initint, initfloat, rtol): initfloat.noise_weights, initfloat.pol_angles, dtype_float=initfloat.dtype, + comm=brahmap.bMPI.comm, ) new_npix, observed_pixels, __, __ = cw.get_pix_mask_pol( @@ -286,6 +294,7 @@ def test_flag_bad_pixel_samples(self, initint, initfloat): initfloat.noise_weights, initfloat.pol_angles, dtype_float=initfloat.dtype, + comm=brahmap.bMPI.comm, ) __, __, old2new_pixel, pixel_flag = cw.get_pix_mask_pol(