diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 23f293ad1e..4aa3df1f2b 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1437,7 +1437,6 @@ def DummyExpr(*args, init=False): # Nodes required for distributed-memory halo exchange - class HaloSpot(Node): """ @@ -1463,6 +1462,10 @@ def __init__(self, body, halo_scheme): self._halo_scheme = halo_scheme + def __repr__(self): + functions = "(%s)" % ",".join(i.name for i in self.functions) + return "<%s%s>" % (self.__class__.__name__, functions) + @property def halo_scheme(self): return self._halo_scheme @@ -1495,10 +1498,6 @@ def body(self): def functions(self): return tuple(self.fmapper) - def __repr__(self): - funcs = self.halo_scheme.__reprfuncs__() - return "<%s(%s)>" % (self.__class__.__name__, funcs) - # Utility classes diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index d80127fae2..3819caccb4 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -11,7 +11,7 @@ from devito.ir.support import Forward, Scope from devito.symbolics.manipulation import _uxreplace_registry from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten, - frozendict, is_integer, filter_sorted, OrderedSet) + frozendict, is_integer, filter_sorted) from devito.types import Grid __all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch'] @@ -62,7 +62,7 @@ def __repr__(self): OMapper = namedtuple('OMapper', 'core owned') -class HaloScheme(): +class HaloScheme: """ A HaloScheme describes a set of halo exchanges through a mapper: @@ -120,17 +120,9 @@ def __init__(self, exprs, ispace): self._honored[i.root] = frozenset([(ltk, rtk)]) self._honored = frozendict(self._honored) - def __reprfuncs__(self): - fstrings = [] - for f in self.fmapper.keys(): - loc_indices = OrderedSet(*(self.fmapper[f].loc_indices.values())) - loc_indices_str = str(list(loc_indices)) if loc_indices else "" - fstrings.append("%s%s" % (f.name, loc_indices_str)) - - return ",".join(fstrings) - def __repr__(self): - return "<%s(%s)>" % (self.__class__.__name__, self.__reprfuncs__()) + fnames = ",".join(i.name for i in set(self._mapper)) + return "HaloScheme<%s>" % fnames def __eq__(self, other): return (isinstance(other, HaloScheme) and @@ -545,8 +537,6 @@ def classify(exprs, ispace): loc_indices, loc_dirs = process_loc_indices(raw_loc_indices, ispace.directions) - halos = frozenset(halos) - dims = frozenset(dims) mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims) diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 030dfacfad..3bc6a61344 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -101,19 +101,18 @@ def _hoist_invariant(iet): if not any(r(dep, hs1, v.loc_indices) for r in rules): break else: - # hs1 can be hoisted out of `it`, but we need to infer valid + # `hs1`` can be hoisted out of `it`, but we need to infer valid # loc_indices hse = hs1.halo_scheme.fmapper[f] - raw_loc_indices = {} + loc_indices = {} for d, v in hse.loc_indices.items(): if v in it.uindices: - v_sub = it.start - raw_loc_indices[d] = v.symbolic_min.subs(it.dim, v_sub) + loc_indices[d] = v.symbolic_min.subs(it.dim, it.start) else: - raw_loc_indices[d] = v + loc_indices[d] = v - hse = hse._rebuild(loc_indices=raw_loc_indices) + hse = hse._rebuild(loc_indices=loc_indices) hs1.halo_scheme.fmapper[f] = hse hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f) @@ -357,14 +356,11 @@ def _filter_iter_mapper(iet): def _make_cond_mapper(iet): - + "Return a mapper from HaloSpots to the Conditionals that contain them." cond_mapper = {} for hs, v in MapHaloSpots().visit(iet).items(): - conditionals = set() - for i in v: - if i.is_Conditional and not isinstance(i.condition, GuardFactorEq): - conditionals.add(i) - + conditionals = {i for i in v if i.is_Conditional and + not isinstance(i.condition, GuardFactorEq)} cond_mapper[hs] = conditionals return cond_mapper diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 05a5b4b82d..9d2032c472 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -1299,7 +1299,7 @@ def test_unmerge_haloupdate_if_no_locindices(self, mode): assert np.allclose(g.data_ro_domain[0, 5:], [16., 16., 14., 13., 6.], rtol=R) @pytest.mark.parallel(mode=1) - def test_merge_haloupdate_if_diff_locindices_v0(self, mode): + def test_merge_haloupdate_if_diff_locindices(self, mode): grid = Grid(shape=(101, 101)) x, y = grid.dimensions t = grid.stepping_dim @@ -1320,11 +1320,12 @@ def test_merge_haloupdate_if_diff_locindices_v0(self, mode): op.cfunction @pytest.mark.parallel(mode=2) - def test_merge_haloupdate_if_diff_locindices_v1(self, mode): + def test_merge_and_hoist_haloupdate_if_diff_locindices(self, mode): """ This test is a revisited, more complex version of - `test_merge_haloupdate_if_diff_locindices_v0`. And in addition to - checking the generated code, it also checks the numerical output. + `test_merge_haloupdate_if_diff_locindices`, also checking hoisting. + And in addition to checking the generated code, + it also checks the numerical output. In the Operator there are three Eqs: