Skip to content

Commit

Permalink
Merge pull request #2451 from devitocodes/halo-buff1
Browse files Browse the repository at this point in the history
compiler: Fix handling of modulo 1 for MPI
  • Loading branch information
mloubout authored Sep 10, 2024
2 parents ca0bcd0 + eab3579 commit 63dcfb1
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 26 deletions.
2 changes: 1 addition & 1 deletion devito/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def reinit_compiler(val):
deprecate='openmp')

# MPI mode (0 => disabled, 1 == basic)
preprocessor = lambda i: bool(i) if isinstance(i, int) else i
preprocessor = lambda i: {0: False, 1: 'basic'}.get(i, i)
configuration.add('mpi', 0, [0, 1] + list(mpi_registry),
preprocessor=preprocessor, callback=reinit_compiler)

Expand Down
8 changes: 2 additions & 6 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,8 @@ def rule(size, e):
groups = as_mapper(mds, lambda d: d.modulo)
for size, v in groups.items():
key = partial(rule, size)
if size == 1:
# Optimization -- avoid useless "% 1" ModuloDimensions
subs = {md.origin: 0 for md in v}
else:
subs = {md.origin: md for md in v}
sub_iterators[d].extend(v)
subs = {md.origin: md for md in v}
sub_iterators[d].extend(v)

func = partial(xreplace_indices, mapper=subs, key=key)
exprs = [e.apply(func) for e in exprs]
Expand Down
1 change: 0 additions & 1 deletion devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,6 @@ def _call_poke(self, poke):


mpi_registry = {
True: BasicHaloExchangeBuilder,
'basic': BasicHaloExchangeBuilder,
'diag': DiagHaloExchangeBuilder,
'diag2': Diag2HaloExchangeBuilder,
Expand Down
36 changes: 21 additions & 15 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,25 +246,31 @@ def remove_redundant_moddims(iet):
if not mds:
return iet

mapper = as_mapper(mds, key=lambda md: md.offset % md.modulo)
# ModuloDimensions are defined in Iteration headers, hence they must be
# removed from there first of all
mapper = {}
for n in FindNodes(Iteration).visit(iet):
candidates = [d for d in n.uindices if d in mds]

subs = {}
for k, v in mapper.items():
chosen = v.pop(0)
subs.update({d: chosen for d in v})
degenerates, others = split(candidates, lambda d: d.modulo == 1)
subs = {d: sympy.S.Zero for d in degenerates}

body = Uxreplace(subs).visit(iet.body)
iet = iet._rebuild(body=body)
redundants = as_mapper(others, key=lambda d: d.offset % d.modulo)
for k, v in redundants.items():
chosen = v.pop(0)
subs.update({d: chosen for d in v})

# ModuloDimensions are defined in Iteration headers, hence they must be
# removed from there too
subs = {}
for n in FindNodes(Iteration).visit(iet):
if not set(n.uindices) & set(mds):
continue
subs[n] = n._rebuild(uindices=filter_ordered(n.uindices))
if subs:
# Expunge the ModuloDimensions from the Iteration header
uindices = [d for d in n.uindices if d not in subs]
iteration = n._rebuild(uindices=uindices)

iet = Transformer(subs, nested=True).visit(iet)
# Replace the ModuloDimensions in the Iteration body
iteration = Uxreplace(subs).visit(iteration)

mapper[n] = iteration

iet = Transformer(mapper, nested=True).visit(iet)

return iet

Expand Down
2 changes: 1 addition & 1 deletion tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_reduction_many_dims(self):
op1 = Operator(eqns, opt=('advanced', {'mapify-reduce': True}))

tree, = retrieve_iteration_tree(op0)
assert 'collapse(4) reduction(+:s)' in str(tree.root.pragmas[0])
assert 'collapse(3) reduction(+:s)' in str(tree[1].pragmas[0])

tree, = retrieve_iteration_tree(op1)
assert 'collapse(3) reduction(+:s)' in str(tree[1].pragmas[0])
Expand Down
26 changes: 24 additions & 2 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

from conftest import _R, assert_blocking, assert_structure
from devito import (Grid, Constant, Function, TimeFunction, SparseFunction,
SparseTimeFunction, Dimension, ConditionalDimension,
SparseTimeFunction, Dimension, ConditionalDimension, div,
SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm,
inner, configuration, switchconfig, generic_derivative,
PrecomputedSparseFunction, DefaultDimension)
PrecomputedSparseFunction, DefaultDimension, Buffer)
from devito.arch.compiler import OneapiCompiler
from devito.data import LEFT, RIGHT
from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols,
Expand Down Expand Up @@ -1639,6 +1639,28 @@ def test_enforce_haloupdate_if_unwritten_function(self, mode):
calls = FindNodes(Call).visit(op)
assert len(calls) == 2 # One for `v` and one for `usave`

@pytest.mark.parallel(mode=1)
def test_haloupdate_buffer1(self, mode):
grid = Grid(shape=(4, 4))
x, y = grid.dimensions

u = TimeFunction(name='u', grid=grid, time_order=1, save=Buffer(1))
v = TimeFunction(name='v', grid=grid, time_order=1, save=Buffer(1))

eqns = [Eq(u.forward, div(v) + 1.),
Eq(v.forward, div(u.forward) + 1.)]

op = Operator(eqns)

calls = FindNodes(Call).visit(op)
# There should be two separate calls
# halo(v), eq_u, halo_u, eq(v)
assert len(calls) == 2

# Also ensure the compiler is doing its job removing unnecessary
# ModuloDimensions
assert len([i for i in FindSymbols('dimensions').visit(op) if i.is_Modulo]) == 0


class TestOperatorAdvanced:

Expand Down

0 comments on commit 63dcfb1

Please sign in to comment.