Skip to content

Commit

Permalink
compile: Improve codegen aesthetics of cire-rotate
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Oct 21, 2024
1 parent bac17a1 commit ca308e8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 20 deletions.
33 changes: 30 additions & 3 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import sympy

from devito.finite_differences import Max, Min
from devito.ir import (Any, Forward, List, Prodder, FindApplications, FindNodes,
FindSymbols, Transformer, Uxreplace, filter_iterations,
retrieve_iteration_tree, pull_dims)
from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder,
FindApplications, FindNodes, FindSymbols, Transformer,
Uxreplace, filter_iterations, retrieve_iteration_tree,
pull_dims)
from devito.passes.iet.engine import iet_pass
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper,
Expand Down Expand Up @@ -231,10 +232,13 @@ def minimize_symbols(iet):
* Remove redundant ModuloDimensions (e.g., due to using the
`save=Buffer(2)` API)
* Simplify Iteration headers (e.g., ModuloDimensions with identical
starting point and step)
* Abridge SubDimension names where possible to declutter generated
loop nests and shrink indices
"""
iet = remove_redundant_moddims(iet)
iet = simplify_iteration_headers(iet)
iet = abridge_dim_names(iet)

return iet, {}
Expand Down Expand Up @@ -264,6 +268,29 @@ def remove_redundant_moddims(iet):
return iet


def simplify_iteration_headers(iet):
mapper = {}
for i in FindNodes(Iteration).visit(iet):
candidates = [d for d in i.uindices
if d.is_Modulo and d.symbolic_min == d.symbolic_incr]

# Don't touch `t0, t1, ...` for codegen aesthetics and to avoid
# massive changes in the test suite
candidates = [d for d in candidates if not d.is_Time]

if not candidates:
continue

uindices = [d for d in i.uindices if d not in candidates]
stmts = [DummyExpr(d, d.symbolic_incr, init=True) for d in candidates]

mapper[i] = i._rebuild(nodes=tuple(stmts) + i.nodes, uindices=uindices)

iet = Transformer(mapper, nested=True).visit(iet)

return iet


@singledispatch
def abridge_dim_names(iet):
return iet
Expand Down
16 changes: 8 additions & 8 deletions tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,17 +1319,17 @@ def test_multiple_subnests_v1(self):
bns, _ = assert_blocking(op, {'x0_blk0'})

trees = retrieve_iteration_tree(bns['x0_blk0'])
assert len(trees) == 2
assert len(trees) == 4

assert trees[0][0] is trees[1][0]
assert trees[0][0].pragmas[0].ccode.value ==\
assert len(set(i.root for i in trees)) == 1
assert trees[-2].root.pragmas[0].ccode.value ==\
'omp for collapse(2) schedule(dynamic,1)'
assert not trees[0][2].pragmas
assert not trees[0][3].pragmas
assert trees[0][4].pragmas[0].ccode.value ==\
assert not trees[-2][2].pragmas
assert not trees[-2][3].pragmas
assert trees[-2][4].pragmas[0].ccode.value ==\
'omp parallel for schedule(dynamic,1) num_threads(nthreads_nested)'
assert not trees[1][2].pragmas
assert trees[1][3].pragmas[0].ccode.value ==\
assert not trees[-1][2].pragmas
assert trees[-1][3].pragmas[0].ccode.value ==\
'omp parallel for schedule(dynamic,1) num_threads(nthreads_nested)'

@pytest.mark.parametrize('blocklevels', [1, 2])
Expand Down
25 changes: 16 additions & 9 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,13 +1131,13 @@ def test_from_different_nests(self, rotate):
# Check code generation
bns, _ = assert_blocking(op1, {'x0_blk0', 'x1_blk0'})
trees = retrieve_iteration_tree(bns['x0_blk0'])
assert len(trees) == 2
assert trees[0][-1].nodes[0].body[0].write.is_Array
assert trees[1][-1].nodes[0].body[0].write is u
assert len(trees) == 4 if rotate else 2
assert trees[-2][-1].nodes[0].body[0].write.is_Array
assert trees[-1][-1].nodes[0].body[0].write is u
trees = retrieve_iteration_tree(bns['x1_blk0'])
assert len(trees) == 2
assert trees[0][-1].nodes[0].body[0].write.is_Array
assert trees[1][-1].nodes[0].body[0].write is v
assert len(trees) == 4 if rotate else 2
assert trees[-2][-1].nodes[0].body[0].write.is_Array
assert trees[-1][-1].nodes[0].body[0].write is v

# Check numerical output
op0(time_M=1)
Expand Down Expand Up @@ -2093,9 +2093,12 @@ def test_maxpar_option(self, rotate):
# Check code generation
bns, _ = assert_blocking(op1, {'x0_blk0'})
trees = retrieve_iteration_tree(bns['x0_blk0'])
assert len(trees) == 2
if rotate:
assert len(trees) == 5
else:
assert len(trees) == 2
assert trees[0][2] is not trees[1][2]
assert trees[0][1] is trees[1][1]
assert trees[0][2] is not trees[1][2]

# Check numerical output
op0.apply(time_M=2)
Expand Down Expand Up @@ -2191,7 +2194,11 @@ def test_blocking_options(self, rotate):
if rotate:
assert_structure(
op1,
prefix + ['t,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z',
prefix + ['t,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,yc',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,yc,z',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,z'],
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z,y,yc,z,z'
Expand Down

0 comments on commit ca308e8

Please sign in to comment.