From bb080dca8e6f50658bde5b9ba9d1c95a6dac3e1f Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 2 Apr 2024 14:25:51 +0000 Subject: [PATCH 1/5] compiler: Sketch C-level MPI reductions --- devito/ir/clusters/algorithms.py | 87 +++++++++++++++++++++++++++++--- devito/ir/clusters/cluster.py | 8 ++- devito/ir/stree/algorithms.py | 12 +++-- devito/ir/support/basic.py | 7 +++ devito/mpi/reduction_scheme.py | 44 ++++++++++++++++ devito/mpi/routines.py | 41 +++++++++++++-- devito/operator/operator.py | 2 +- devito/passes/iet/mpi.py | 12 ++++- tests/test_error_checking.py | 23 +++++++++ 9 files changed, 217 insertions(+), 19 deletions(-) create mode 100644 devito/mpi/reduction_scheme.py diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 043f0121e5..bce705b40a 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -8,12 +8,13 @@ from devito.exceptions import InvalidOperator from devito.finite_differences.elementary import Max, Min from devito.ir.support import (Any, Backward, Forward, IterationSpace, erange, - pull_dims) + pull_dims, null_ispace) from devito.ir.equations import OpMin, OpMax from devito.ir.clusters.analysis import analyze from devito.ir.clusters.cluster import Cluster, ClusterGroup from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass from devito.mpi.halo_scheme import HaloScheme, HaloTouch +from devito.mpi.reduction_scheme import DistributedReduction from devito.symbolics import (limits_mapper, retrieve_indexed, uxreplace, xreplace_indices) from devito.tools import (DefaultOrderedDict, Stamp, as_mapper, flatten, @@ -48,7 +49,7 @@ def clusterize(exprs, **kwargs): clusters = normalize(clusters, **kwargs) # Derive the necessary communications for distributed-memory parallelism - clusters = Communications().process(clusters) + clusters = communications(clusters) return ClusterGroup(clusters) @@ -365,11 +366,24 @@ def rule(size, e): return processed -class Communications(Queue): - +@timed_pass(name='communications') +def communications(clusters): """ Enrich a sequence of Clusters by adding special Clusters representing data - communications, or "halo exchanges", for distributed parallelism. + communications for distributed parallelism. + """ + clusters = HaloComms().process(clusters) + clusters = reduction_comms(clusters) + + return clusters + + +class Comms(Queue): + #TODO: MAYBE DROP ME + + """ + Abstract base class for injecting Clusters representing communications + for distributed-memory parallelism. """ _q_guards_in_key = True @@ -377,7 +391,13 @@ class Communications(Queue): B = Symbol(name='⊥') - @timed_pass(name='communications') + +class HaloComms(Comms): + + """ + A specialization of Comms to handle halo exchanges. 
+ """ + def process(self, clusters): return self._process_fatd(clusters, 1, seen=set()) @@ -432,6 +452,54 @@ def callback(self, clusters, prefix, seen=None): return processed +def reduction_comms(clusters): + # Detect the underlying Grid + #TODO: pretty rudimentary, but it's a start + for c in clusters: + try: + grid = c.grid + break + except ValueError: + continue + else: + return clusters + + # Detect global reductions along the distributed Dimensions + found = {} + for c in clusters: + if not any(grid.is_distributed(d) for d in c.ispace.itdims): + continue + + for e in c.exprs: + op = e.operation + if op is None: + continue + elif found.get(e.lhs, op) != op: + raise ValueError("Inconsistent reduction operations") + else: + found[e.lhs] = e.operation + + # Place global reductions right before they're required + processed = [] + for c in clusters: + for var, op in list(found.items()): + if var in c.scope.read_only: + expr = Eq(var, DistributedReduction(var, op=op, grid=grid)) + processed.append(c.rebuild(exprs=expr)) + + found.pop(var) + + processed.append(c) + + # Leftover reductions are placed at the very end + while found: + var, op = found.popitem() + expr = Eq(var, DistributedReduction(var, op=op, grid=grid)) + processed.append(Cluster(exprs=[expr], ispace=null_ispace)) + + return processed + + def normalize(clusters, **kwargs): options = kwargs['options'] sregistry = kwargs['sregistry'] @@ -562,7 +630,12 @@ def _normalize_reductions_dense(cluster, sregistry, mapper): # because the Function might be padded, and reduction operations # require, in general, the data values to be contiguous name = sregistry.make_name() - a = mapper[rhs] = Array(name=name, dtype=e.dtype, dimensions=dims) + try: + grid = cluster.grid + except ValueError: + grid = None + a = mapper[rhs] = Array(name=name, dtype=e.dtype, dimensions=dims, + grid=grid) processed.extend([Eq(a.indexify(), rhs), e.func(lhs, a.indexify())]) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index a96fcf5e1c..5e56be055a 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -11,6 +11,7 @@ normalize_properties, normalize_syncs, minimum, maximum, null_ispace) from devito.mpi.halo_scheme import HaloScheme, HaloTouch +from devito.mpi.reduction_scheme import DistributedReduction from devito.symbolics import estimate_cost from devito.tools import as_tuple, flatten, frozendict, infer_dtype from devito.types import WeakFence, CriticalRegion @@ -232,12 +233,17 @@ def is_wild(self): """ True if encoding a non-mathematical operation, False otherwise. """ - return self.is_halo_touch or self.is_fence + return self.is_halo_touch or self.is_dist_reduce or self.is_fence @property def is_halo_touch(self): return self.exprs and all(isinstance(e.rhs, HaloTouch) for e in self.exprs) + @property + def is_dist_reduce(self): + return self.exprs and all(isinstance(e.rhs, DistributedReduction) + for e in self.exprs) + @property def is_fence(self): return self.is_weak_fence or self.is_critical_region diff --git a/devito/ir/stree/algorithms.py b/devito/ir/stree/algorithms.py index 99113d9bb8..73ac219166 100644 --- a/devito/ir/stree/algorithms.py +++ b/devito/ir/stree/algorithms.py @@ -136,12 +136,11 @@ def stree_build(clusters, profiler=None, **kwargs): def preprocess(clusters, options=None, **kwargs): """ - Lower the so-called "wild" Clusters, that is objects not representing a set - of mathematical operations. 
This boils down to: + Lower the so-called "wild" Clusters, that is objects not representing + mathematical operations. This boils down to: - * Moving the HaloTouch's from `clusters` into a mapper `M: {HT -> C}`. - `c = M(ht)` is the first Cluster of the sequence requiring the halo - exchange `ht` to have terminated before the execution can proceed. + * Bind HaloTouch to Clusters. A Cluster carrying a HaloTouch cannot execute + before the HaloExchange has completed. * Lower the CriticalRegions: * If they encode an asynchronous operation (e.g., a WaitLock), attach it to a Nop Cluster for future lowering; @@ -156,6 +155,9 @@ def preprocess(clusters, options=None, **kwargs): hs = HaloScheme.union(e.rhs.halo_scheme for e in c.exprs) queue.append(c.rebuild(exprs=[], halo_scheme=hs)) + elif c.is_dist_reduce: + processed.append(c) + elif c.is_critical_region and c.syncs: processed.append(c.rebuild(exprs=None, guards=c.guards, syncs=c.syncs)) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index fb125a9594..c435800872 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -990,6 +990,13 @@ def reads(self): """ return as_mapper(self.reads_gen(), key=lambda i: i.function) + @cached_property + def read_only(self): + """ + Create a mapper from functions to read accesses. + """ + return set(self.reads) - set(self.writes) + @cached_property def initialized(self): return frozenset(e.lhs.function for e in self.exprs diff --git a/devito/mpi/reduction_scheme.py b/devito/mpi/reduction_scheme.py new file mode 100644 index 0000000000..bf2ccb0bbd --- /dev/null +++ b/devito/mpi/reduction_scheme.py @@ -0,0 +1,44 @@ +import sympy + +from devito.tools import Reconstructable + +__all__ = ['DistributedReduction'] + + +class DistributedReduction(sympy.Function, Reconstructable): + + """ + A SymPy object representing a distributed Reduction. + """ + + __rargs__ = ('var',) + __rkwargs__ = ('op', 'grid') + + def __new__(cls, var, op=None, grid=None, **kwargs): + obj = sympy.Function.__new__(cls, var, **kwargs) + obj.op = op + obj.grid = grid + return obj + + def __repr__(self): + return "DistributedReduction(%s,%s)" % (self.var, self.op) + + __str__ = __repr__ + + def _sympystr(self, printer): + return str(self) + + def _hashable_content(self): + return (self.op, self.grid) + + def __eq__(self, other): + return (isinstance(other, DistributedReduction) and + self.var == other.var and + self.op == other.op and + self.grid == other.grid) + + func = Reconstructable._rebuild + + @property + def var(self): + return self.args[0] diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index d3b6011ff1..2968a2d4fb 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -8,7 +8,7 @@ from sympy import Integer from devito.data import OWNED, HALO, NOPAD, LEFT, CENTER, RIGHT -from devito.ir.equations import DummyEq +from devito.ir.equations import DummyEq, OpInc, OpMin, OpMax from devito.ir.iet import (Call, Callable, Conditional, ElementalFunction, Expression, ExpressionBundle, AugmentedExpression, Iteration, List, Prodder, Return, make_efunc, FindNodes, @@ -21,13 +21,13 @@ from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject, CompositeObject, CustomDimension) -__all__ = ['HaloExchangeBuilder', 'mpi_registry'] +__all__ = ['HaloExchangeBuilder', 'ReductionBuilder, ''mpi_registry'] class HaloExchangeBuilder: """ - Build IET-based routines to implement MPI halo exchange. + Build IET routines to generate MPI halo exchanges. 
""" def __new__(cls, mpimode, generators=None, rcompile=None, sregistry=None, **kwargs): @@ -1351,3 +1351,38 @@ def _arg_values(self, args=None, **kwargs): except AttributeError: setattr(entry, a.name, mapper[a][0]) return values + + +class AllreduceCall(Call): + + def __init__(self, arguments, **kwargs): + super().__init__('MPI_Allreduce', arguments) + + +class ReductionBuilder(object): + + """ + Build IET routines performing MPI reductions. + """ + + mapper = { + OpInc: 'MPI_SUM', + OpMax: 'MPI_MAX', + OpMin: 'MPI_MIN', + } + + def make(self, dr): + """ + Construct Callables and Calls implementing distributed-memory reductions. + """ + f = dr.var + comm = dr.grid.distributor._obj_comm + + inplace = Macro('MPI_IN_PLACE') + mpitype = Macro(dtype_to_mpitype(f.dtype)) + op = self.mapper[dr.op] + + arguments = [inplace, Byref(f), Integer(1), mpitype, op, comm] + allreduce = AllreduceCall(arguments) + + return allreduce diff --git a/devito/operator/operator.py b/devito/operator/operator.py index c9e84ac1c0..d1d2daeb5c 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -1049,7 +1049,7 @@ def __setstate__(self, state): # if applied in cascade (e.g., `linearization` on top of `linearization`) rcompile_registry = { 'avoid_denormals': False, - 'mpi': False, + #'mpi': False, #TODO: DROP / DON'T DROP?? NEED IT FOR GLB REDUCTIONS... 'linearize': False, 'place-transfers': False } diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 0ac3ab51d5..1a1e4d122a 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -7,7 +7,8 @@ retrieve_iteration_tree) from devito.ir.support import PARALLEL, Scope from devito.mpi.halo_scheme import HaloScheme -from devito.mpi.routines import HaloExchangeBuilder +from devito.mpi.reduction_scheme import DistributedReduction +from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder from devito.passes.iet.engine import iet_pass from devito.tools import generator @@ -298,12 +299,13 @@ def _mark_overlappable(iet): @iet_pass def make_mpi(iet, mpimode=None, **kwargs): """ - Inject MPI Callables and Calls implementing halo exchanges for + Inject MPI Callables and Calls implementing halo exchanges and reductions for distributed-memory parallelism. 
""" # To produce unique object names generators = {'msg': generator(), 'comm': generator(), 'comp': generator()} + # Halo exchanges sync_heb = HaloExchangeBuilder('basic', generators, **kwargs) user_heb = HaloExchangeBuilder(mpimode, generators, **kwargs) mapper = {} @@ -328,6 +330,12 @@ def make_mpi(iet, mpimode=None, **kwargs): break iet = Transformer(mapper, nested=True).visit(iet) + # Reductions + rb = ReductionBuilder() + mapper = {e: rb.make(e.expr.rhs) for e in FindNodes(Expression).visit(iet) + if isinstance(e.expr.rhs, DistributedReduction)} + iet = Transformer(mapper, nested=True).visit(iet) + return iet, {'includes': ['mpi.h'], 'efuncs': efuncs} diff --git a/tests/test_error_checking.py b/tests/test_error_checking.py index 061cb2b575..22d36d01bf 100644 --- a/tests/test_error_checking.py +++ b/tests/test_error_checking.py @@ -25,3 +25,26 @@ def test_stability(expr): with pytest.raises(ExecutionError): op.apply(time_M=200, dt=.1) + + +@switchconfig(safe_math=True) +@pytest.mark.parallel(mode=2) +def test_stability_mpi(): + grid = Grid(shape=(10, 10)) + + f = Function(name='f', grid=grid, space_order=2) # noqa + u = TimeFunction(name='u', grid=grid, space_order=2) + v = TimeFunction(name='v', grid=grid, space_order=2) + + eq = Eq(u.forward, u/f) + + op = Operator(eq, opt=('advanced', {'errctl': 'max'})) + + # Check generated code + assert 'MPI_Allreduce' in str(op) + + u.data[:] = 1. + v.data[:] = 2. + + with pytest.raises(ExecutionError): + op.apply(time_M=200, dt=.1) From a07c0e74b4c59f09c0ec47e384d7b709c2ea9ee4 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 4 Apr 2024 09:12:38 +0000 Subject: [PATCH 2/5] compiler: Tweak SparseFunction reconstruction --- tests/test_mpi.py | 7 +++++-- tests/test_sparse.py | 13 ++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 2e4a7fcca7..2bb2ed04d7 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -2637,14 +2637,17 @@ def run_adjoint_F(self, nd): solver = acoustic_setup(shape=shape, spacing=[15. for _ in shape], tn=tn, space_order=so, nrec=nrec, preset='layers-isotropic', dtype=np.float64) + # Run forward operator - rec, u, _ = solver.forward() + src = solver.geometry.src + rec, u, _ = solver.forward(src=src) assert np.isclose(norm(u) / Eu, 1.0) assert np.isclose(norm(rec) / Erec, 1.0) # Run adjoint operator - srca, v, _ = solver.adjoint(rec=rec) + srca = src.func(name='srca') + srca, v, _ = solver.adjoint(srca=srca, rec=rec) assert np.isclose(norm(v) / Ev, 1.0) assert np.isclose(norm(srca) / Esrca, 1.0) diff --git a/tests/test_sparse.py b/tests/test_sparse.py index c63a104cd7..1dcb7df2eb 100644 --- a/tests/test_sparse.py +++ b/tests/test_sparse.py @@ -417,10 +417,9 @@ def test_rebuild(self, sptype): assert getattr(sp, subf).name.startswith("s_") # Rebuild with different name, this should drop the function - # and create new data + # and create new data, while the coordinates and more generally all + # SubFunctions remain the same sp2 = sp._rebuild(name="sr") - - # Check new subfunction for subf in sp2._sub_functions: if getattr(sp2, subf) is not None: assert getattr(sp2, subf) == getattr(sp, subf) @@ -432,6 +431,14 @@ def test_rebuild(self, sptype): assert getattr(sp2, subf).name.startswith("sr2_") assert getattr(sp2, subf).data is None + # Rebuild with different name and dimensions. 
This is expected to recreate + # the SubFunctions as well + sp2 = sp._rebuild(name="sr3", dimensions=None) + for subf in sp2._sub_functions: + if getattr(sp2, subf) is not None: + assert getattr(sp2, subf).name.startswith("sr3_") + assert np.all(getattr(sp2, subf).data == 0) + @pytest.mark.parametrize('sptype', _sptypes) def test_subs(self, sptype): grid = Grid((3, 3, 3)) From 1f16737feea5f36492f123c65ff231fe009555b4 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 2 Apr 2024 15:00:48 +0000 Subject: [PATCH 3/5] compiler: Add support for C-level MPI_Allreduce --- devito/builtins/arithmetic.py | 60 +++++++++++++------------ devito/builtins/utils.py | 52 ++++++---------------- devito/core/gpu.py | 19 ++++---- devito/ir/clusters/algorithms.py | 75 ++++++++++++-------------------- devito/ir/clusters/cluster.py | 14 +++--- devito/mpi/reduction_scheme.py | 20 +++++---- devito/mpi/routines.py | 7 +-- devito/operator/operator.py | 21 +++++---- devito/passes/iet/langbase.py | 1 - devito/passes/iet/mpi.py | 43 ++++++++++-------- tests/test_builtins.py | 14 +++++- tests/test_mpi.py | 6 ++- 12 files changed, 161 insertions(+), 171 deletions(-) diff --git a/devito/builtins/arithmetic.py b/devito/builtins/arithmetic.py index bb0e31806d..f24a0e56ea 100644 --- a/devito/builtins/arithmetic.py +++ b/devito/builtins/arithmetic.py @@ -1,7 +1,7 @@ import numpy as np import devito as dv -from devito.builtins.utils import MPIReduction +from devito.builtins.utils import make_retval __all__ = ['norm', 'sumall', 'sum', 'inner', 'mmin', 'mmax'] @@ -44,15 +44,15 @@ def norm(f, order=2): p, eqns = f.guard() if f.is_SparseFunction else (f, []) dtype = accumulator_mapper[f.dtype] + n = make_retval(f.grid, dtype) s = dv.types.Symbol(name='sum', dtype=dtype) - with MPIReduction(f, dtype=dtype) as mr: - op = dv.Operator([dv.Eq(s, 0.0)] + eqns + - [dv.Inc(s, dv.Abs(Pow(p, order))), dv.Eq(mr.n[0], s)], - name='norm%d' % order) - op.apply(**kwargs) + op = dv.Operator([dv.Eq(s, 0.0)] + eqns + + [dv.Inc(s, dv.Abs(Pow(p, order))), dv.Eq(n[0], s)], + name='norm%d' % order) + op.apply(**kwargs) - v = np.power(mr.v, 1/order) + v = np.power(n.data[0], 1/order) return f.dtype(v) @@ -129,15 +129,15 @@ def sumall(f): p, eqns = f.guard() if f.is_SparseFunction else (f, []) dtype = accumulator_mapper[f.dtype] + n = make_retval(f.grid, dtype) s = dv.types.Symbol(name='sum', dtype=dtype) - with MPIReduction(f, dtype=dtype) as mr: - op = dv.Operator([dv.Eq(s, 0.0)] + eqns + - [dv.Inc(s, p), dv.Eq(mr.n[0], s)], - name='sum') - op.apply(**kwargs) + op = dv.Operator([dv.Eq(s, 0.0)] + eqns + + [dv.Inc(s, p), dv.Eq(n[0], s)], + name='sum') + op.apply(**kwargs) - return f.dtype(mr.v) + return f.dtype(n.data[0]) @dv.switchconfig(log_level='ERROR') @@ -184,15 +184,15 @@ def inner(f, g): rhs, eqns = f.guard(f*g) if f.is_SparseFunction else (f*g, []) dtype = accumulator_mapper[f.dtype] + n = make_retval(f.grid or g.grid, dtype) s = dv.types.Symbol(name='sum', dtype=dtype) - with MPIReduction(f, g, dtype=dtype) as mr: - op = dv.Operator([dv.Eq(s, 0.0)] + eqns + - [dv.Inc(s, rhs), dv.Eq(mr.n[0], s)], - name='inner') - op.apply(**kwargs) + op = dv.Operator([dv.Eq(s, 0.0)] + eqns + + [dv.Inc(s, rhs), dv.Eq(n[0], s)], + name='inner') + op.apply(**kwargs) - return f.dtype(mr.v) + return f.dtype(n.data[0]) @dv.switchconfig(log_level='ERROR') @@ -208,11 +208,14 @@ def mmin(f): if isinstance(f, dv.Constant): return f.data elif isinstance(f, dv.types.dense.DiscreteFunction): - with MPIReduction(f, op=dv.mpi.MPI.MIN) as mr: - mr.n.data[0] = 
np.min(f.data_ro_domain).item() - return mr.v.item() + v = np.min(f.data_ro_domain) + if f.grid is None or not dv.configuration['mpi']: + return v.item() + else: + comm = f.grid.distributor.comm + return comm.allreduce(v, dv.mpi.MPI.MIN).item() else: - raise ValueError("Expected Function, not `%s`" % type(f)) + raise ValueError("Expected Function, got `%s`" % type(f)) @dv.switchconfig(log_level='ERROR') @@ -228,8 +231,11 @@ def mmax(f): if isinstance(f, dv.Constant): return f.data elif isinstance(f, dv.types.dense.DiscreteFunction): - with MPIReduction(f, op=dv.mpi.MPI.MAX) as mr: - mr.n.data[0] = np.max(f.data_ro_domain).item() - return mr.v.item() + v = np.max(f.data_ro_domain) + if f.grid is None or not dv.configuration['mpi']: + return v.item() + else: + comm = f.grid.distributor.comm + return comm.allreduce(v, dv.mpi.MPI.MAX).item() else: - raise ValueError("Expected Function, not `%s`" % type(f)) + raise ValueError("Expected Function, got `%s`" % type(f)) diff --git a/devito/builtins/utils.py b/devito/builtins/utils.py index fe5e0cdb9d..786dbbce48 100644 --- a/devito/builtins/utils.py +++ b/devito/builtins/utils.py @@ -1,52 +1,26 @@ from functools import wraps -import numpy as np - import devito as dv from devito.symbolics import uxreplace from devito.tools import as_tuple -__all__ = ['MPIReduction', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args'] +__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args'] -class MPIReduction: +def make_retval(grid, dtype): """ - A context manager to build MPI-aware reduction Operators. + Devito does not support passing values by reference. This function + creates a dummy Function of size 1 to store the return value of a builtin + applied to `f`. """ - - def __init__(self, *functions, op=dv.mpi.MPI.SUM, dtype=None): - grids = {f.grid for f in functions} - if len(grids) == 0: - self.grid = None - elif len(grids) == 1: - self.grid = grids.pop() - else: - raise ValueError("Multiple Grids found") - if dtype is not None: - self.dtype = dtype - else: - dtype = {f.dtype for f in functions} - if len(dtype) == 1: - self.dtype = np.result_type(dtype.pop(), np.float32).type - else: - raise ValueError("Illegal mixed data types") - self.v = None - self.op = op - - def __enter__(self): - i = dv.Dimension(name='mri',) - self.n = dv.Function(name='n', shape=(1,), dimensions=(i,), - grid=self.grid, dtype=self.dtype, space='host') - self.n.data[:] = 0 - return self - - def __exit__(self, exc_type, exc_value, traceback): - if self.grid is None or not dv.configuration['mpi']: - assert self.n.data.size == 1 - self.v = self.n.data[0] - else: - comm = self.grid.distributor.comm - self.v = comm.allreduce(np.asarray(self.n.data), self.op)[0] + if grid is None: + raise ValueError("Expected Grid, got None") + + i = dv.Dimension(name='mri',) + n = dv.Function(name='n', shape=(1,), dimensions=(i,), grid=grid, + dtype=dtype, space='host') + n.data[:] = 0 + return n def nbl_to_padsize(nbl, ndim): diff --git a/devito/core/gpu.py b/devito/core/gpu.py index 8d7ea75195..266c198647 100644 --- a/devito/core/gpu.py +++ b/devito/core/gpu.py @@ -116,19 +116,20 @@ def _normalize_gpu_fit(cls, oo, **kwargs): return as_tuple(cls.GPU_FIT) @classmethod - def _rcompile_wrapper(cls, **kwargs0): - options = kwargs0['options'] + def _rcompile_wrapper(cls, **kwargs): + def wrapper(expressions, mode='default', **options): - def wrapper(expressions, mode='default', **kwargs1): if mode == 'host': - kwargs = {**{ + par_disabled = kwargs['options']['par-disabled'] + target = { 
'platform': 'cpu64', - 'language': 'C' if options['par-disabled'] else 'openmp', - 'compiler': 'custom', - }, **kwargs1} + 'language': 'C' if par_disabled else 'openmp', + 'compiler': 'custom' + } else: - kwargs = {**kwargs0, **kwargs1} - return rcompile(expressions, kwargs) + target = None + + return rcompile(expressions, kwargs, options, target=target) return wrapper diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index bce705b40a..b2b552530d 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -14,11 +14,11 @@ from devito.ir.clusters.cluster import Cluster, ClusterGroup from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass from devito.mpi.halo_scheme import HaloScheme, HaloTouch -from devito.mpi.reduction_scheme import DistributedReduction +from devito.mpi.reduction_scheme import DistReduce from devito.symbolics import (limits_mapper, retrieve_indexed, uxreplace, xreplace_indices) from devito.tools import (DefaultOrderedDict, Stamp, as_mapper, flatten, - is_integer, timed_pass, toposort) + is_integer, split, timed_pass, toposort) from devito.types import Array, Eq, Symbol from devito.types.dimension import BOTTOM, ModuloDimension @@ -378,12 +378,10 @@ def communications(clusters): return clusters -class Comms(Queue): - #TODO: MAYBE DROP ME +class HaloComms(Queue): """ - Abstract base class for injecting Clusters representing communications - for distributed-memory parallelism. + Inject Clusters representing halo exchanges for distributed-memory parallelism. """ _q_guards_in_key = True @@ -391,13 +389,6 @@ class Comms(Queue): B = Symbol(name='⊥') - -class HaloComms(Comms): - - """ - A specialization of Comms to handle halo exchanges. - """ - def process(self, clusters): return self._process_fatd(clusters, 1, seen=set()) @@ -453,49 +444,41 @@ def callback(self, clusters, prefix, seen=None): def reduction_comms(clusters): - # Detect the underlying Grid - #TODO: pretty rudimentary, but it's a start - for c in clusters: - try: - grid = c.grid - break - except ValueError: - continue - else: - return clusters - - # Detect global reductions along the distributed Dimensions - found = {} + processed = [] + fifo = [] for c in clusters: - if not any(grid.is_distributed(d) for d in c.ispace.itdims): - continue - + # Schedule the global reductions encountered before `c`, if the + # IterationSpace of `c` is such that the reduction can be carried out + found, fifo = split(fifo, lambda dr: dr.ispace.is_subset(c.ispace)) + if found: + exprs = [Eq(dr.var, dr) for dr in found] + processed.append(c.rebuild(exprs=exprs)) + + # Detect the global reductions in `c` for e in c.exprs: op = e.operation - if op is None: + if op is None or c.is_sparse: continue - elif found.get(e.lhs, op) != op: - raise ValueError("Inconsistent reduction operations") - else: - found[e.lhs] = e.operation - # Place global reductions right before they're required - processed = [] - for c in clusters: - for var, op in list(found.items()): - if var in c.scope.read_only: - expr = Eq(var, DistributedReduction(var, op=op, grid=grid)) - processed.append(c.rebuild(exprs=expr)) + var = e.lhs + grid = c.grid + if grid is None: + continue + + # The IterationSpace within which the global reduction is carried out + ispace = c.ispace.project(lambda d: d in var.free_symbols) + if ispace.itdims == c.ispace.itdims: + # Inc/Max/Min/... 
being used for a non-reduction operation + continue - found.pop(var) + fifo.append(DistReduce(var, op=op, grid=grid, ispace=ispace)) processed.append(c) # Leftover reductions are placed at the very end - while found: - var, op = found.popitem() - expr = Eq(var, DistributedReduction(var, op=op, grid=grid)) - processed.append(Cluster(exprs=[expr], ispace=null_ispace)) + if fifo: + exprs = [Eq(dr.var, dr) for dr in fifo] + processed.append(Cluster(exprs=exprs, ispace=null_ispace)) return processed diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 5e56be055a..a429e4714f 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -11,7 +11,7 @@ normalize_properties, normalize_syncs, minimum, maximum, null_ispace) from devito.mpi.halo_scheme import HaloScheme, HaloTouch -from devito.mpi.reduction_scheme import DistributedReduction +from devito.mpi.reduction_scheme import DistReduce from devito.symbolics import estimate_cost from devito.tools import as_tuple, flatten, frozendict, infer_dtype from devito.types import WeakFence, CriticalRegion @@ -181,8 +181,11 @@ def has_increments(self): @cached_property def grid(self): - grids = set(f.grid for f in self.functions if f.is_DiscreteFunction) - {None} - if len(grids) == 1: + grids = set(f.grid for f in self.functions if f.is_AbstractFunction) + grids.discard(None) + if len(grids) == 0: + return None + elif len(grids) == 1: return grids.pop() else: raise ValueError("Cluster has no unique Grid") @@ -211,7 +214,7 @@ def is_dense(self): dims = {d for d in self.properties if d._defines & target} if any(pset & self.properties[d] for d in dims): return True - except ValueError: + except (AttributeError, ValueError): pass # Fallback to legacy is_dense checks @@ -241,8 +244,7 @@ def is_halo_touch(self): @property def is_dist_reduce(self): - return self.exprs and all(isinstance(e.rhs, DistributedReduction) - for e in self.exprs) + return self.exprs and all(isinstance(e.rhs, DistReduce) for e in self.exprs) @property def is_fence(self): diff --git a/devito/mpi/reduction_scheme.py b/devito/mpi/reduction_scheme.py index bf2ccb0bbd..f3a412f07d 100644 --- a/devito/mpi/reduction_scheme.py +++ b/devito/mpi/reduction_scheme.py @@ -2,26 +2,27 @@ from devito.tools import Reconstructable -__all__ = ['DistributedReduction'] +__all__ = ['DistReduce'] -class DistributedReduction(sympy.Function, Reconstructable): +class DistReduce(sympy.Function, Reconstructable): """ A SymPy object representing a distributed Reduction. 
""" __rargs__ = ('var',) - __rkwargs__ = ('op', 'grid') + __rkwargs__ = ('op', 'grid', 'ispace') - def __new__(cls, var, op=None, grid=None, **kwargs): + def __new__(cls, var, op=None, grid=None, ispace=None, **kwargs): obj = sympy.Function.__new__(cls, var, **kwargs) obj.op = op obj.grid = grid + obj.ispace = ispace return obj def __repr__(self): - return "DistributedReduction(%s,%s)" % (self.var, self.op) + return "DistReduce(%s,%s)" % (self.var, self.op) __str__ = __repr__ @@ -29,13 +30,16 @@ def _sympystr(self, printer): return str(self) def _hashable_content(self): - return (self.op, self.grid) + return (self.op, self.grid, self.ispace) def __eq__(self, other): - return (isinstance(other, DistributedReduction) and + return (isinstance(other, DistReduce) and self.var == other.var and self.op == other.op and - self.grid == other.grid) + self.grid == other.grid and + self.ispace == other.ispace) + + __hash__ = sympy.Function.__hash__ func = Reconstructable._rebuild diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index 2968a2d4fb..46fc2a8e3a 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -21,7 +21,7 @@ from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject, CompositeObject, CustomDimension) -__all__ = ['HaloExchangeBuilder', 'ReductionBuilder, ''mpi_registry'] +__all__ = ['HaloExchangeBuilder', 'ReductionBuilder', 'mpi_registry'] class HaloExchangeBuilder: @@ -30,7 +30,8 @@ class HaloExchangeBuilder: Build IET routines to generate MPI halo exchanges. """ - def __new__(cls, mpimode, generators=None, rcompile=None, sregistry=None, **kwargs): + def __new__(cls, mpimode, generators=None, rcompile=None, sregistry=None, + **kwargs): obj = object.__new__(mpi_registry[mpimode]) obj.rcompile = rcompile @@ -370,7 +371,7 @@ def _make_copy(self, f, hse, key, swap=False): eqns.append(Eq(*swap(buf[[i] + bdims], f[[i] + findices]))) # Compile `eqns` into an IET via recursive compilation - irs, _ = self.rcompile(eqns) + irs, _ = self.rcompile(eqns, mpi=False) parameters = [buf] + bshape + list(f.handles) + ofs diff --git a/devito/operator/operator.py b/devito/operator/operator.py index d1d2daeb5c..5c781512b9 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -271,9 +271,9 @@ def _lower(cls, expressions, **kwargs): return IRs(expressions, clusters, stree, uiet, iet), byproduct @classmethod - def _rcompile_wrapper(cls, **kwargs0): - def wrapper(expressions, **kwargs1): - return rcompile(expressions, {**kwargs0, **kwargs1}) + def _rcompile_wrapper(cls, **kwargs): + def wrapper(expressions, **options): + return rcompile(expressions, kwargs, options) return wrapper @classmethod @@ -1049,26 +1049,25 @@ def __setstate__(self, state): # if applied in cascade (e.g., `linearization` on top of `linearization`) rcompile_registry = { 'avoid_denormals': False, - #'mpi': False, #TODO: DROP / DON'T DROP?? NEED IT FOR GLB REDUCTIONS... 'linearize': False, 'place-transfers': False } -def rcompile(expressions, kwargs=None): +def rcompile(expressions, kwargs, options, target=None): """ Perform recursive compilation on an ordered sequence of symbolic expressions. 
""" - if not kwargs or 'options' not in kwargs: - kwargs = parse_kwargs(**kwargs) + options = {**kwargs['options'], **rcompile_registry, **options} + + if target is None: cls = operator_selector(**kwargs) - kwargs = cls._normalize_kwargs(**kwargs) else: + kwargs = parse_kwargs(**target) cls = operator_selector(**kwargs) + kwargs = cls._normalize_kwargs(**kwargs) - # Tweak the compilation kwargs - options = dict(kwargs['options']) - options.update(rcompile_registry) + # Use the customized opt options kwargs['options'] = options # Recursive profiling not supported -- would be a complete mess diff --git a/devito/passes/iet/langbase.py b/devito/passes/iet/langbase.py index 0331760be8..d27674c419 100644 --- a/devito/passes/iet/langbase.py +++ b/devito/passes/iet/langbase.py @@ -432,7 +432,6 @@ def _(iet): break except AttributeError: pass - assert objcomm is not None devicetype = as_list(self.lang[self.platform]) deviceid = self.deviceid diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 1a1e4d122a..3a714e095a 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -7,7 +7,7 @@ retrieve_iteration_tree) from devito.ir.support import PARALLEL, Scope from devito.mpi.halo_scheme import HaloScheme -from devito.mpi.reduction_scheme import DistributedReduction +from devito.mpi.reduction_scheme import DistReduce from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder from devito.passes.iet.engine import iet_pass from devito.tools import generator @@ -297,15 +297,13 @@ def _mark_overlappable(iet): @iet_pass -def make_mpi(iet, mpimode=None, **kwargs): +def make_halo_exchanges(iet, mpimode=None, **kwargs): """ - Inject MPI Callables and Calls implementing halo exchanges and reductions for - distributed-memory parallelism. + Lower HaloSpots into halo exchanges for distributed-memory parallelism. """ # To produce unique object names generators = {'msg': generator(), 'comm': generator(), 'comp': generator()} - # Halo exchanges sync_heb = HaloExchangeBuilder('basic', generators, **kwargs) user_heb = HaloExchangeBuilder(mpimode, generators, **kwargs) mapper = {} @@ -330,26 +328,33 @@ def make_mpi(iet, mpimode=None, **kwargs): break iet = Transformer(mapper, nested=True).visit(iet) - # Reductions + return iet, {'includes': ['mpi.h'], 'efuncs': efuncs} + + +@iet_pass +def make_reductions(iet, mpimode=None, **kwargs): rb = ReductionBuilder() - mapper = {e: rb.make(e.expr.rhs) for e in FindNodes(Expression).visit(iet) - if isinstance(e.expr.rhs, DistributedReduction)} + + mapper = {} + for e in FindNodes(Expression).visit(iet): + if not isinstance(e.expr.rhs, DistReduce): + continue + elif mpimode: + mapper[e] = rb.make(e.expr.rhs) + else: + mapper[e] = None iet = Transformer(mapper, nested=True).visit(iet) - return iet, {'includes': ['mpi.h'], 'efuncs': efuncs} + return iet, {} def mpiize(graph, **kwargs): """ - Perform two IET passes: + Perform three IET passes: - * Optimization of communications - * Injection of MPI code - - The former is implemented by manipulating HaloSpots. - - The latter resorts to creating MPI Callables and replacing HaloSpots with Calls - to MPI Callables. 
+ * Optimization of halo exchanges + * Injection of code for halo exchanges + * Injection of code for reductions """ options = kwargs['options'] @@ -358,4 +363,6 @@ def mpiize(graph, **kwargs): mpimode = options['mpi'] if mpimode: - make_mpi(graph, mpimode=mpimode, **kwargs) + make_halo_exchanges(graph, mpimode=mpimode, **kwargs) + + make_reductions(graph, mpimode=mpimode, **kwargs) diff --git a/tests/test_builtins.py b/tests/test_builtins.py index 21b4ca0830..32e60c0912 100644 --- a/tests/test_builtins.py +++ b/tests/test_builtins.py @@ -448,7 +448,7 @@ def test_sum_sparse(self): def test_min_max_sparse(self): """ - Test that mmin/mmax work on SparseFunction + Test that mmin/mmax work on SparseFunction. """ grid = Grid((101, 101), extent=(1000., 1000.)) @@ -464,6 +464,18 @@ def test_min_max_sparse(self): term2 = mmax(rec0) assert np.isclose(term1/term2 - 1, 0.0, rtol=0.0, atol=1e-5) + @pytest.mark.parallel(mode=4) + def test_min_max_mpi(self): + grid = Grid(shape=(100, 100)) + + f = Function(name='f', grid=grid) + + # Populate data with increasing values starting at 1 + f.data[:] = np.arange(1, 10001).reshape((100, 100)) + + assert mmin(f) == 1 + assert mmax(f) == 10000 + def test_issue_1860(self): grid = Grid(shape=(401, 301, 181)) diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 2bb2ed04d7..d7bd5188d5 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -13,7 +13,8 @@ from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols, retrieve_iteration_tree) from devito.mpi import MPI -from devito.mpi.routines import HaloUpdateCall, HaloUpdateList, MPICall, ComputeCall +from devito.mpi.routines import (HaloUpdateCall, HaloUpdateList, MPICall, + ComputeCall, AllreduceCall) from devito.mpi.distributed import CustomTopology from devito.tools import Bunch @@ -928,7 +929,8 @@ def test_avoid_haloupdate_as_nostencil_advanced(self, mode): # No stencil in the expressions, so no halo update required! 
calls = FindNodes(Call).visit(op) - assert len(calls) == 0 + assert len(calls) == 2 + assert all(isinstance(i, AllreduceCall) for i in calls) @pytest.mark.parallel(mode=1) def test_avoid_redundant_haloupdate(self, mode): From 05129060c8a73bf17429531e808e213c014560c8 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 5 Apr 2024 08:03:58 +0000 Subject: [PATCH 4/5] compiler: Fix detection of global distributed reductions --- devito/ir/clusters/algorithms.py | 21 ++++++++++++++++----- devito/ir/clusters/cluster.py | 13 +++++++++++++ devito/ir/support/space.py | 2 +- tests/test_mpi.py | 13 ++++++------- 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index b2b552530d..e3f815a2f4 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -447,14 +447,14 @@ def reduction_comms(clusters): processed = [] fifo = [] for c in clusters: - # Schedule the global reductions encountered before `c`, if the - # IterationSpace of `c` is such that the reduction can be carried out + # Schedule the global distributed reductions encountered before `c`, + # if `c`'s IterationSpace is such that the reduction can be carried out found, fifo = split(fifo, lambda dr: dr.ispace.is_subset(c.ispace)) if found: exprs = [Eq(dr.var, dr) for dr in found] processed.append(c.rebuild(exprs=exprs)) - # Detect the global reductions in `c` + # Detect the global distributed reductions in `c` for e in c.exprs: op = e.operation if op is None or c.is_sparse: @@ -465,12 +465,23 @@ def reduction_comms(clusters): if grid is None: continue - # The IterationSpace within which the global reduction is carried out + # Is Inc/Max/Min/... actually used for a reduction? ispace = c.ispace.project(lambda d: d in var.free_symbols) if ispace.itdims == c.ispace.itdims: - # Inc/Max/Min/... being used for a non-reduction operation continue + # The reduced Dimensions + rdims = set(c.ispace.itdims) - set(ispace.itdims) + + # The reduced Dimensions inducing a global distributed reduction + grdims = {d for d in rdims if d._defines & c.dist_dimensions} + if not grdims: + continue + + # The IterationSpace within which the global distributed reduction + # must be carried out + ispace = c.ispace.prefix(lambda d: d in var.free_symbols) + fifo.append(DistReduce(var, op=op, grid=grid, ispace=ispace)) processed.append(c) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index a429e4714f..3179cc44ae 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -167,6 +167,19 @@ def used_dimensions(self): idims = set.union(*[set(e.implicit_dims) for e in self.exprs]) return {i for i in self.free_symbols if i.is_Dimension} | idims + @cached_property + def dist_dimensions(self): + """ + The Cluster's distributed Dimensions. 
+ """ + ret = set() + for f in self.functions: + try: + ret.update(f._dist_dimensions) + except AttributeError: + pass + return frozenset(ret) + @cached_property def scope(self): return Scope(self.exprs) diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index 34f9a110e7..dc7b26ec7c 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -954,7 +954,7 @@ def prefix(self, key): try: i = self.project(key)[-1] except IndexError: - return None + return null_ispace return self[:self.index(i.dim) + 1] diff --git a/tests/test_mpi.py b/tests/test_mpi.py index d7bd5188d5..7a4257faea 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -4,17 +4,17 @@ from conftest import _R, assert_blocking, assert_structure from devito import (Grid, Constant, Function, TimeFunction, SparseFunction, - SparseTimeFunction, Dimension, ConditionalDimension, SubDimension, - SubDomain, Eq, Ne, Inc, NODE, Operator, norm, inner, configuration, - switchconfig, generic_derivative, PrecomputedSparseFunction, - DefaultDimension) + SparseTimeFunction, Dimension, ConditionalDimension, + SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm, + inner, configuration, switchconfig, generic_derivative, + PrecomputedSparseFunction, DefaultDimension) from devito.arch.compiler import OneapiCompiler from devito.data import LEFT, RIGHT from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols, retrieve_iteration_tree) from devito.mpi import MPI from devito.mpi.routines import (HaloUpdateCall, HaloUpdateList, MPICall, - ComputeCall, AllreduceCall) + ComputeCall) from devito.mpi.distributed import CustomTopology from devito.tools import Bunch @@ -929,8 +929,7 @@ def test_avoid_haloupdate_as_nostencil_advanced(self, mode): # No stencil in the expressions, so no halo update required! 
calls = FindNodes(Call).visit(op) - assert len(calls) == 2 - assert all(isinstance(i, AllreduceCall) for i in calls) + assert len(calls) == 0 @pytest.mark.parallel(mode=1) def test_avoid_redundant_haloupdate(self, mode): From 35aea7ca497d48e80525f7be3eecf5e1c7ce12c3 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 24 May 2024 11:41:04 -0400 Subject: [PATCH 5/5] CI: leftover parallel mark missing --- devito/types/sparse.py | 5 ++++- tests/test_builtins.py | 2 +- tests/test_error_checking.py | 2 +- tests/test_sparse.py | 7 +++++-- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/devito/types/sparse.py b/devito/types/sparse.py index 186ec03935..d6845a4dc5 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -80,7 +80,10 @@ def __indices_setup__(cls, *args, **kwargs): else: sparse_dim = Dimension(name='p_%s' % kwargs["name"]) - dimensions = as_tuple(kwargs.get('dimensions', sparse_dim)) + dimensions = as_tuple(kwargs.get('dimensions')) + if not dimensions: + dimensions = (sparse_dim,) + if args: return tuple(dimensions), tuple(args) else: diff --git a/tests/test_builtins.py b/tests/test_builtins.py index 32e60c0912..102875d35a 100644 --- a/tests/test_builtins.py +++ b/tests/test_builtins.py @@ -465,7 +465,7 @@ def test_min_max_sparse(self): assert np.isclose(term1/term2 - 1, 0.0, rtol=0.0, atol=1e-5) @pytest.mark.parallel(mode=4) - def test_min_max_mpi(self): + def test_min_max_mpi(self, mode): grid = Grid(shape=(100, 100)) f = Function(name='f', grid=grid) diff --git a/tests/test_error_checking.py b/tests/test_error_checking.py index 22d36d01bf..f03cf708e4 100644 --- a/tests/test_error_checking.py +++ b/tests/test_error_checking.py @@ -29,7 +29,7 @@ def test_stability(expr): @switchconfig(safe_math=True) @pytest.mark.parallel(mode=2) -def test_stability_mpi(): +def test_stability_mpi(mode): grid = Grid(shape=(10, 10)) f = Function(name='f', grid=grid, space_order=2) # noqa diff --git a/tests/test_sparse.py b/tests/test_sparse.py index 1dcb7df2eb..78613beaba 100644 --- a/tests/test_sparse.py +++ b/tests/test_sparse.py @@ -426,6 +426,8 @@ def test_rebuild(self, sptype): # Rebuild with different name as an alias sp2 = sp._rebuild(name="sr2", alias=True) + assert sp2.name == "sr2" + assert sp2.dimensions == sp.dimensions for subf in sp2._sub_functions: if getattr(sp2, subf) is not None: assert getattr(sp2, subf).name.startswith("sr2_") @@ -434,10 +436,11 @@ def test_rebuild(self, sptype): # Rebuild with different name and dimensions. This is expected to recreate # the SubFunctions as well sp2 = sp._rebuild(name="sr3", dimensions=None) + assert sp2.name == "sr3" + assert sp2.dimensions == sp.dimensions for subf in sp2._sub_functions: if getattr(sp2, subf) is not None: - assert getattr(sp2, subf).name.startswith("sr3_") - assert np.all(getattr(sp2, subf).data == 0) + assert getattr(sp2, subf) == getattr(sp, subf) @pytest.mark.parametrize('sptype', _sptypes) def test_subs(self, sptype):
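
Editorial note (not part of the patches above): the core of this series is the lowering of a DistReduce node into an in-place MPI_Allreduce over a single element, with the MPI operation selected by ReductionBuilder.mapper (OpInc -> MPI_SUM, OpMax -> MPI_MAX, OpMin -> MPI_MIN). The sketch below shows the Python-level analogue of that generated C call using mpi4py; the variable names and the string keys of the op table are illustrative only and are not taken from Devito's generated code.

    # Hedged sketch: Python analogue of the C code emitted by ReductionBuilder.make(),
    # i.e. MPI_Allreduce(MPI_IN_PLACE, &var, 1, <mpitype>, <op>, comm);
    # Names below are illustrative, not Devito identifiers.
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD

    # Mirror of ReductionBuilder.mapper: reduction operation -> MPI reduction op
    mpi_op = {'OpInc': MPI.SUM, 'OpMax': MPI.MAX, 'OpMin': MPI.MIN}

    # A rank-local scalar result, playing the role of the n[0] written by the
    # norm/sumall/inner builtins after this series
    local = np.array([float(comm.rank + 1)])

    # In-place allreduce over a single element, as in the generated AllreduceCall
    comm.Allreduce(MPI.IN_PLACE, local, op=mpi_op['OpInc'])

    # Every rank now holds the same global value
    print(comm.rank, local[0])

Run with, for example, `mpirun -n 4 python sketch.py`: each rank prints the identical global sum, which is the behaviour exercised at the C level by test_stability_mpi (MPI_Allreduce in the generated code) and at the Python level by test_min_max_mpi.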