Skip to content

Commit

Permalink
compiler: Improve HaloSpot with reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Nov 26, 2024
1 parent b6671da commit 819ba0f
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 37 deletions.
24 changes: 13 additions & 11 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,8 +1438,20 @@ def DummyExpr(*args, init=False):

# Nodes required for distributed-memory halo exchange

class HaloMixin:

class HaloSpot(Node):
def __repr__(self):
fstrings = []
for f in self.fmapper.keys():
loc_indices = OrderedSet(*(self.fmapper[f].loc_indices.values()))
loc_indices_str = str(list(loc_indices)) if loc_indices else ""
fstrings.append("%s%s" % (f.name, loc_indices_str))

functions = ",".join(fstrings)
return "<%s(%s)>" % (self.__class__.__name__, functions)


class HaloSpot(HaloMixin, Node):

"""
A halo exchange operation (e.g., send, recv, wait, ...) required to
Expand All @@ -1464,16 +1476,6 @@ def __init__(self, body, halo_scheme):

self._halo_scheme = halo_scheme

def __repr__(self):
fstrings = []
for f in self.functions:
loc_indices = OrderedSet(*(self.halo_scheme.fmapper[f].loc_indices.values()))
loc_indices_str = str(list(loc_indices)) if loc_indices else ""
fstrings.append("%s%s" % (f.name, loc_indices_str))

functions = ",".join(fstrings)
return "<%s(%s)>" % (self.__class__.__name__, functions)

@property
def halo_scheme(self):
return self._halo_scheme
Expand Down
13 changes: 2 additions & 11 deletions devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from devito import configuration
from devito.data import CORE, OWNED, LEFT, CENTER, RIGHT
from devito.ir.support import Forward, Scope
from devito.ir.iet.nodes import HaloMixin
from devito.symbolics.manipulation import _uxreplace_registry
from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten,
frozendict, is_integer, filter_sorted, OrderedSet)
Expand Down Expand Up @@ -62,7 +63,7 @@ def __repr__(self):
OMapper = namedtuple('OMapper', 'core owned')


class HaloScheme:
class HaloScheme(HaloMixin):

"""
A HaloScheme describes a set of halo exchanges through a mapper:
Expand Down Expand Up @@ -120,16 +121,6 @@ def __init__(self, exprs, ispace):
self._honored[i.root] = frozenset([(ltk, rtk)])
self._honored = frozendict(self._honored)

def __repr__(self):
fstrings = []
for f in self.fmapper:
loc_indices = OrderedSet(*(self._mapper[f].loc_indices.values()))
loc_indices_str = str(list(loc_indices)) if loc_indices else ""
fstrings.append("%s%s" % (f.name, loc_indices_str))

functions = ",".join(fstrings)
return "<%s(%s)>" % (self.__class__.__name__, functions)

def __eq__(self, other):
return (isinstance(other, HaloScheme) and
self._mapper == other._mapper and
Expand Down
24 changes: 10 additions & 14 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,17 @@ def _hoist_invariant(iet):
# Precompute scopes to save time
scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()}

cond_mapper = _make_cond_mapper(iet)

# Analysis
hsmapper = {}
imapper = defaultdict(list)

cond_mapper = _make_cond_mapper(iet)
iter_mapper = _filter_iter_mapper(iet)

for it, halo_spots in iter_mapper.items():
for hs0, hs1 in combinations(halo_spots, r=2):

if _ensure_control_flow(hs0, hs1, cond_mapper):
if _check_control_flow(hs0, hs1, cond_mapper):
continue

# If there are overlapping loc_indices, skip
Expand All @@ -110,13 +109,12 @@ def _hoist_invariant(iet):
hse = hs1.halo_scheme.fmapper[f]
raw_loc_indices = {}

for d in hse.loc_indices:
md = hse.loc_indices[d]
if md in it.uindices:
md_sub = it.start
raw_loc_indices[d] = md.symbolic_min.subs(it.dim, md_sub)
for d, v in hse.loc_indices.items():
if v in it.uindices:
v_sub = it.start
raw_loc_indices[d] = v.symbolic_min.subs(it.dim, v_sub)
else:
raw_loc_indices[d] = md
raw_loc_indices[d] = v

hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hse
Expand Down Expand Up @@ -151,10 +149,8 @@ def _merge_halospots(iet):
"""

# Analysis
cond_mapper = _make_cond_mapper(iet)

mapper = {}

cond_mapper = _make_cond_mapper(iet)
iter_mapper = _filter_iter_mapper(iet)

for it, halo_spots in iter_mapper.items():
Expand All @@ -164,7 +160,7 @@ def _merge_halospots(iet):

for hs1 in halo_spots[1:]:

if _ensure_control_flow(hs0, hs1, cond_mapper):
if _check_control_flow(hs0, hs1, cond_mapper):
continue

for f, v in hs1.fmapper.items():
Expand Down Expand Up @@ -370,7 +366,7 @@ def _make_cond_mapper(iet):
for hs, v in cond_mapper.items()}


def _ensure_control_flow(hs0, hs1, cond_mapper):
def _check_control_flow(hs0, hs1, cond_mapper):
"""
If there are Conditionals involved, both `hs0` and `hs1` must be
within the same Conditional, otherwise we would break control flow
Expand Down
2 changes: 1 addition & 1 deletion examples/seismic/viscoacoustic/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def ForwardOperator(model, geometry, space_order=4, kernel='sls', time_order=2,

# Substitute spacing terms to reduce flops
return Operator(eqn + src_term + rec_term, subs=model.spacing_map,
name='ViscoAcForward', **kwargs)
name='ViscoIsoAcousticForward', **kwargs)


def AdjointOperator(model, geometry, space_order=4, kernel='SLS', time_order=2, **kwargs):
Expand Down

0 comments on commit 819ba0f

Please sign in to comment.