Skip to content

Commit

Permalink
added a function that takes a mpi_comm OR initalizes mpi env, and set…
Browse files Browse the repository at this point in the history
…s the global attribute bMPI
  • Loading branch information
anand-avinash committed May 28, 2024
1 parent f2323d5 commit 54c203b
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 1 deletion.
13 changes: 12 additions & 1 deletion brahmap/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
from . import interfaces, utilities, linop, mapmakers, _extensions

__all__ = ["interfaces", "utilities", "linop", "mapmakers", "_extensions"]
from .utilities import Initialize

bMPI = None

__all__ = [
"interfaces",
"utilities",
"linop",
"mapmakers",
"_extensions",
"Initialize",
]
3 changes: 3 additions & 0 deletions brahmap/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

from .process_time_samples import ProcessTimeSamples, SolverType

from .mpi import Initialize

__all__ = [
"is_sorted",
"bash_colors",
Expand All @@ -38,4 +40,5 @@
"ProcessTimeSamples",
"SolverType",
"TypeChangeWarning",
"Initialize",
]
64 changes: 64 additions & 0 deletions brahmap/utilities/mpi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os
import brahmap

import mpi4py

mpi4py.rc.initialize = False
mpi4py.rc.finalize = False

from mpi4py import MPI # noqa: E402

if MPI.Is_initialized() is False:
MPI.Init_thread(required=MPI.THREAD_FUNNELED)


def Initialize(communicator=None, raise_exception_per_process: bool = True):
if brahmap.bMPI is None:
brahmap.bMPI = _MPI(
comm=communicator, raise_exception_per_process=raise_exception_per_process
)


class _MPI(object):
def __init__(self, comm, raise_exception_per_process: bool) -> None:
if comm is None:
self.comm = MPI.COMM_WORLD
else:
self.comm = comm
self.size = self.comm.size
self.rank = self.comm.rank
self.raise_exception_per_process = raise_exception_per_process

if "OMP_NUM_THREADS" in os.environ:
self.nthreads_per_process = os.environ.get("OMP_NUM_THREADS")
else:
self.nthreads_per_process = 1


def MPI_RAISE_EXCEPTION(
condition,
exception,
message,
):
"""Will raise `exception` with `message` if the `condition` is false.
Args:
condition (_type_): The condition to be evaluated
exception (_type_): The exception to throw
message (_type_): The message to pass to the `Exception`
Raises:
exception: _description_
exception: _description_
"""

if brahmap.bMPI.raise_exception_per_process:
if condition is False:
error_str = f"Exception raised by MPI rank {brahmap.bMPI.rank}\n"
raise exception(error_str + message)
else:
exception_count = brahmap.bMPI.comm.reduce(condition, MPI.SUM, 0)

if brahmap.bMPI.rank == 0:
error_str = f"Exception raised by {brahmap.bMPI.comm.size - exception_count} MPI process(es)\n"
raise exception(error_str + message)

0 comments on commit 54c203b

Please sign in to comment.