diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 76876d687a..86e432ad71 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1438,8 +1438,20 @@ def DummyExpr(*args, init=False): # Nodes required for distributed-memory halo exchange +class HaloMixin: -class HaloSpot(Node): + def __repr__(self): + fstrings = [] + for f in self.fmapper.keys(): + loc_indices = OrderedSet(*(self.fmapper[f].loc_indices.values())) + loc_indices_str = str(list(loc_indices)) if loc_indices else "" + fstrings.append("%s%s" % (f.name, loc_indices_str)) + + functions = ",".join(fstrings) + return "<%s(%s)>" % (self.__class__.__name__, functions) + + +class HaloSpot(HaloMixin, Node): """ A halo exchange operation (e.g., send, recv, wait, ...) required to @@ -1464,16 +1476,6 @@ def __init__(self, body, halo_scheme): self._halo_scheme = halo_scheme - def __repr__(self): - fstrings = [] - for f in self.functions: - loc_indices = OrderedSet(*(self.halo_scheme.fmapper[f].loc_indices.values())) - loc_indices_str = str(list(loc_indices)) if loc_indices else "" - fstrings.append("%s%s" % (f.name, loc_indices_str)) - - functions = ",".join(fstrings) - return "<%s(%s)>" % (self.__class__.__name__, functions) - @property def halo_scheme(self): return self._halo_scheme diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index cf8edadf11..92bfa92d03 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -9,6 +9,7 @@ from devito import configuration from devito.data import CORE, OWNED, LEFT, CENTER, RIGHT from devito.ir.support import Forward, Scope +from devito.ir.iet.nodes import HaloMixin from devito.symbolics.manipulation import _uxreplace_registry from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten, frozendict, is_integer, filter_sorted, OrderedSet) @@ -62,7 +63,7 @@ def __repr__(self): OMapper = namedtuple('OMapper', 'core owned') -class HaloScheme: +class HaloScheme(HaloMixin): """ A HaloScheme describes a set of halo exchanges through a mapper: @@ -120,16 +121,6 @@ def __init__(self, exprs, ispace): self._honored[i.root] = frozenset([(ltk, rtk)]) self._honored = frozendict(self._honored) - def __repr__(self): - fstrings = [] - for f in self.fmapper: - loc_indices = OrderedSet(*(self._mapper[f].loc_indices.values())) - loc_indices_str = str(list(loc_indices)) if loc_indices else "" - fstrings.append("%s%s" % (f.name, loc_indices_str)) - - functions = ",".join(fstrings) - return "<%s(%s)>" % (self.__class__.__name__, functions) - def __eq__(self, other): return (isinstance(other, HaloScheme) and self._mapper == other._mapper and diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index a3cef2ff8b..fb1684a21a 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -76,18 +76,17 @@ def _hoist_invariant(iet): # Precompute scopes to save time scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()} - cond_mapper = _make_cond_mapper(iet) - # Analysis hsmapper = {} imapper = defaultdict(list) + cond_mapper = _make_cond_mapper(iet) iter_mapper = _filter_iter_mapper(iet) for it, halo_spots in iter_mapper.items(): for hs0, hs1 in combinations(halo_spots, r=2): - if _ensure_control_flow(hs0, hs1, cond_mapper): + if _check_control_flow(hs0, hs1, cond_mapper): continue # If there are overlapping loc_indices, skip @@ -110,13 +109,12 @@ def _hoist_invariant(iet): hse = hs1.halo_scheme.fmapper[f] raw_loc_indices = {} - for d in hse.loc_indices: - md = hse.loc_indices[d] - if md in it.uindices: - md_sub = it.start - raw_loc_indices[d] = md.symbolic_min.subs(it.dim, md_sub) + for d, v in hse.loc_indices.items(): + if v in it.uindices: + v_sub = it.start + raw_loc_indices[d] = v.symbolic_min.subs(it.dim, v_sub) else: - raw_loc_indices[d] = md + raw_loc_indices[d] = v hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices)) hs1.halo_scheme.fmapper[f] = hse @@ -151,10 +149,8 @@ def _merge_halospots(iet): """ # Analysis - cond_mapper = _make_cond_mapper(iet) - mapper = {} - + cond_mapper = _make_cond_mapper(iet) iter_mapper = _filter_iter_mapper(iet) for it, halo_spots in iter_mapper.items(): @@ -164,7 +160,7 @@ def _merge_halospots(iet): for hs1 in halo_spots[1:]: - if _ensure_control_flow(hs0, hs1, cond_mapper): + if _check_control_flow(hs0, hs1, cond_mapper): continue for f, v in hs1.fmapper.items(): @@ -370,7 +366,7 @@ def _make_cond_mapper(iet): for hs, v in cond_mapper.items()} -def _ensure_control_flow(hs0, hs1, cond_mapper): +def _check_control_flow(hs0, hs1, cond_mapper): """ If there are Conditionals involved, both `hs0` and `hs1` must be within the same Conditional, otherwise we would break control flow diff --git a/examples/seismic/viscoacoustic/operators.py b/examples/seismic/viscoacoustic/operators.py index a1119236e6..d237d43ea6 100755 --- a/examples/seismic/viscoacoustic/operators.py +++ b/examples/seismic/viscoacoustic/operators.py @@ -524,7 +524,7 @@ def ForwardOperator(model, geometry, space_order=4, kernel='sls', time_order=2, # Substitute spacing terms to reduce flops return Operator(eqn + src_term + rec_term, subs=model.spacing_map, - name='ViscoAcForward', **kwargs) + name='ViscoIsoAcousticForward', **kwargs) def AdjointOperator(model, geometry, space_order=4, kernel='SLS', time_order=2, **kwargs):