From 90ad3e415aa53217baf82932e479465c96d5d653 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 28 Jun 2024 08:52:04 +0000 Subject: [PATCH] compiler: Patch CSE in presence of conditionals --- devito/passes/clusters/cse.py | 14 ++++++++------ tests/test_dse.py | 25 ++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 5e4ce40d36..97edce538e 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -98,24 +98,26 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'): # Create temporaries hit = max(targets.values()) - temps = [Eq(make(), k) for k, v in targets.items() if v == hit] + chosen = [(k, make()) for k, v in targets.items() if v == hit] # Apply replacements # The extracted temporaries are inserted before the first expression # that contains it + scheduled = [] updated = [] for e in processed: pe = e - for t in temps: - pe, changed = _uxreplace(pe, {t.rhs: t.lhs}) - if changed and t not in updated: - updated.append(t) + for k, v in chosen: + pe, changed = _uxreplace(pe, {k: v}) + if changed and v not in scheduled: + updated.append(pe.func(v, k, operation=None)) + scheduled.append(v) updated.append(pe) processed = updated # Update `exclude` for the same reasons as above -- to rule out CSE across # Dimension-independent data dependences - exclude.update({t.lhs for t in temps}) + exclude.update(scheduled) # At this point we may have useless temporaries (e.g., r0=r1). Let's drop them processed = _compact_temporaries(processed, exclude) diff --git a/tests/test_dse.py b/tests/test_dse.py index 2a52b443a7..1d7071cabc 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -13,7 +13,7 @@ ConditionalDimension, DefaultDimension, Grid, Operator, norm, grad, div, dimensions, switchconfig, configuration, centered, first_derivative, solve, transpose, Abs, cos, - sin, sqrt) + sin, sqrt, Ge) from devito.exceptions import InvalidArgument, InvalidOperator from devito.finite_differences.differentiable import diffify from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes, @@ -168,6 +168,29 @@ def test_cse_temp_order(): assert type(args[2]) is CTemp +def test_cse_w_conditionals(): + grid = Grid(shape=(10, 10, 10)) + x, _, _ = grid.dimensions + + cd = ConditionalDimension(name='cd', parent=x, condition=Ge(x, 4), + indirect=True) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + h = Function(name='h', grid=grid) + a0 = Function(name='a0', grid=grid) + a1 = Function(name='a1', grid=grid) + + eqns = [Eq(h, a0, implicit_dims=cd), + Eq(a0, a0 + f*g, implicit_dims=cd), + Eq(a1, a1 + f*g, implicit_dims=cd)] + + op = Operator(eqns) + + assert_structure(op, ['x,y,z'], 'xyz') + assert len(FindNodes(Conditional).visit(op)) == 1 + + @pytest.mark.parametrize('expr,expected', [ ('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'), ('fa[x]**2', 'fa[x]*fa[x]'),