diff --git a/mpi/MPIFileIO.py b/mpi/MPIFileIO.py index 462776c7..96ce6d99 100644 --- a/mpi/MPIFileIO.py +++ b/mpi/MPIFileIO.py @@ -1,4 +1,3 @@ -from mpi4py import MPI from spt3g import core import random @@ -117,15 +116,20 @@ def __call__(self, frame): return [] # Terminate processing on IO nodes @core.pipesegment -def MPIIODistributor(pipe, mpicomm=MPI.COMM_WORLD, n_io=10, files=[]): +def MPIIODistributor(pipe, mpicomm=None, n_io=10, files=[]): ''' - Read files from disk using the first n_io processes in mpicomm, with - processing of frames in those files occurring on the other processes - in mpicomm. See documentation for MPIFileReader for the format of - the files argument and MPIFrameParallelizer for information on the - semantics of processing. Add this as the first module in your pipeline - in place of core.G3Reader. + Read files from disk using the first n_io processes in mpicomm (COMM_WORLD + by default), with processing of frames in those files occurring on the other + processes in mpicomm. See documentation for MPIFileReader for the format of + the files argument and MPIFrameParallelizer for information on the semantics + of processing. Add this as the first module in your pipeline in place of + core.G3Reader. ''' + if mpicomm is None: + from mpi4py import MPI + + mpicomm = MPI.COMM_WORLD + subcomm = mpicomm.Split(mpicomm.rank < n_io, mpicomm.rank) if mpicomm.rank < n_io: pipe.Add(MPIFileReader, mpicomm=subcomm, files=files)