From 65e05be8155387c83e3182f7d479906622f83ba4 Mon Sep 17 00:00:00 2001 From: Sasha Rahlin Date: Wed, 4 Oct 2023 13:27:41 -0500 Subject: [PATCH] Hide mpi4py import --- mpi/MPIFileIO.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/mpi/MPIFileIO.py b/mpi/MPIFileIO.py index 462776c7..96ce6d99 100644 --- a/mpi/MPIFileIO.py +++ b/mpi/MPIFileIO.py @@ -1,4 +1,3 @@ -from mpi4py import MPI from spt3g import core import random @@ -117,15 +116,20 @@ def __call__(self, frame): return [] # Terminate processing on IO nodes @core.pipesegment -def MPIIODistributor(pipe, mpicomm=MPI.COMM_WORLD, n_io=10, files=[]): +def MPIIODistributor(pipe, mpicomm=None, n_io=10, files=[]): ''' - Read files from disk using the first n_io processes in mpicomm, with - processing of frames in those files occurring on the other processes - in mpicomm. See documentation for MPIFileReader for the format of - the files argument and MPIFrameParallelizer for information on the - semantics of processing. Add this as the first module in your pipeline - in place of core.G3Reader. + Read files from disk using the first n_io processes in mpicomm (COMM_WORLD + by default), with processing of frames in those files occurring on the other + processes in mpicomm. See documentation for MPIFileReader for the format of + the files argument and MPIFrameParallelizer for information on the semantics + of processing. Add this as the first module in your pipeline in place of + core.G3Reader. ''' + if mpicomm is None: + from mpi4py import MPI + + mpicomm = MPI.COMM_WORLD + subcomm = mpicomm.Split(mpicomm.rank < n_io, mpicomm.rank) if mpicomm.rank < n_io: pipe.Add(MPIFileReader, mpicomm=subcomm, files=files)