From 45678853c0f5ca7fbab4a682bb42ef030c6d5157 Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Wed, 27 Nov 2024 14:17:47 +0200 Subject: [PATCH] compiler: Further split tests --- devito/mpi/halo_scheme.py | 2 +- devito/passes/iet/mpi.py | 6 +- tests/test_mpi.py | 118 ++++++++++++++++++++------------------ 3 files changed, 66 insertions(+), 60 deletions(-) diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 92bfa92d03..f1bbb79003 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -12,7 +12,7 @@ from devito.ir.iet.nodes import HaloMixin from devito.symbolics.manipulation import _uxreplace_registry from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten, - frozendict, is_integer, filter_sorted, OrderedSet) + frozendict, is_integer, filter_sorted) from devito.types import Grid __all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch'] diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index fb1684a21a..204fcc2cae 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -34,10 +34,8 @@ def optimize_halospots(iet, **kwargs): def _drop_reduction_halospots(iet): """ - Remove HaloSpots that: - - * Would be used to compute Increments (in which case, a halo exchange - is actually unnecessary) + Remove HaloSpots that are used to compute Increments + (in which case, a halo exchange is actually unnecessary) """ mapper = defaultdict(set) diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 53b34b883f..7d28dfe7bc 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -1008,16 +1008,14 @@ def test_avoid_haloupdate_if_distr_but_sequential(self, mode): calls = FindNodes(Call).visit(op) assert len(calls) == 0 - @pytest.mark.parallel(mode=1) - def test_issue_2448(self, mode): + @pytest.fixture + def setup(self): shape = (2,) so = 2 + tn = 30 grid = Grid(shape=shape) - # Time related - tn = 30 - # Velocity and pressure fields v = TimeFunction(name='v', grid=grid, space_order=so) tau = TimeFunction(name='tau', grid=grid, space_order=so) @@ -1026,86 +1024,76 @@ def test_issue_2448(self, mode): pde_v = v.dt - (tau.dx) pde_tau = tau.dt - ((v.forward).dx) u_v = Eq(v.forward, solve(pde_v, v.forward)) - u_tau = Eq(tau.forward, solve(pde_tau, tau.forward)) - # Test two variants of receiver interpolation + # Receiver rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=tn) rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1) - # The receiver 0 - rec_term0 = rec.interpolate(expr=v) + return grid, v, tau, u_v, u_tau, rec - # The receiver 1 - rec_term1 = rec.interpolate(expr=v.forward) + @pytest.mark.parallel(mode=1) + def test_issue_2448_I(self, mode, setup): + _, v, tau, u_v, u_tau, rec = setup - # Test receiver interpolation 0, here we have a halo exchange hoisted - op0 = Operator([u_v] + [u_tau] + rec_term0) + rec_term0 = rec.interpolate(expr=v) - calls = [i for i in FindNodes(Call).visit(op0) - if isinstance(i, HaloUpdateCall)] + op0 = Operator([u_v, u_tau, rec_term0]) - # The correct we want - assert len(calls) == 3 + calls = [i for i in FindNodes(Call).visit(op0) if isinstance(i, HaloUpdateCall)] + assert len(calls) == 3 assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1 assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2 assert calls[0].arguments[0] is v assert calls[1].arguments[0] is tau assert calls[2].arguments[0] is v - # Test receiver interpolation 1, here we should not have any halo exchange - # hoisted - op1 = Operator([u_v] + [u_tau] + rec_term1) + @pytest.mark.parallel(mode=1) + def test_issue_2448_II(self, mode, setup): + _, v, tau, u_v, u_tau, rec = setup - calls = [i for i in FindNodes(Call).visit(op1) - if isinstance(i, HaloUpdateCall)] + rec_term1 = rec.interpolate(expr=v.forward) - # The correct we want - assert len(calls) == 3 + op1 = Operator([u_v, u_tau, rec_term1]) + calls = [i for i in FindNodes(Call).visit(op1) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 3 assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[0])) == 0 assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[1])) == 3 assert calls[0].arguments[0] is tau assert calls[1].arguments[0] is v assert calls[2].arguments[0] is v - # Further complicate/stree-test adding an artifical example - # with two hoisting opportunities + @pytest.mark.parallel(mode=1) + def test_issue_2448_III(self, mode, setup): + grid, v, tau, u_v, u_tau, rec = setup - # Velocity and pressure fields - v2 = TimeFunction(name='v2', grid=grid, space_order=so) - tau2 = TimeFunction(name='tau2', grid=grid, space_order=so) + # Additional velocity and pressure fields + v2 = TimeFunction(name='v2', grid=grid, space_order=2) + tau2 = TimeFunction(name='tau2', grid=grid, space_order=2) # First order elastic-like dependencies equations pde_v2 = v2.dt - (tau2.dx) pde_tau2 = tau2.dt - ((v2.forward).dx) u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward)) - u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward)) - # Test two variants of receiver interpolation - rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=tn) - rec2.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1) + # Receiver + rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30) + rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1) - # The receiver 2 + rec_term0 = rec.interpolate(expr=v) rec_term2 = rec2.interpolate(expr=v2) - # The receiver 3 - rec_term3 = rec2.interpolate(expr=v2.forward) - - # Test receiver interpolation 0, here we have a halo exchange hoisted - op2 = Operator([u_v] + [u_v2] + [u_tau] + [u_tau2] + rec_term0 + rec_term2) + op2 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term2]) - calls = [i for i in FindNodes(Call).visit(op2) - if isinstance(i, HaloUpdateCall)] + calls = [i for i in FindNodes(Call).visit(op2) if isinstance(i, HaloUpdateCall)] - # The correct we want assert len(calls) == 5 - assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[0])) == 2 assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[1])) == 3 - assert calls[0].arguments[0] is v assert calls[1].arguments[0] is v2 assert calls[2].arguments[0] is tau @@ -1113,18 +1101,34 @@ def test_issue_2448(self, mode): assert calls[3].arguments[0] is v assert calls[4].arguments[0] is v2 - # Test receiver interpolation 0, here we have a halo exchange hoisted - op3 = Operator([u_v] + [u_v2] + [u_tau] + [u_tau2] + rec_term0 + rec_term3) + @pytest.mark.parallel(mode=1) + def test_issue_2448_IV(self, mode, setup): + grid, v, tau, u_v, u_tau, rec = setup - calls = [i for i in FindNodes(Call).visit(op3) - if isinstance(i, HaloUpdateCall)] + # Additional velocity and pressure fields + v2 = TimeFunction(name='v2', grid=grid, space_order=2) + tau2 = TimeFunction(name='tau2', grid=grid, space_order=2) - # The correct we want - assert len(calls) == 5 + # First order elastic-like dependencies equations + pde_v2 = v2.dt - (tau2.dx) + pde_tau2 = tau2.dt - ((v2.forward).dx) + u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward)) + u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward)) + + # Receiver + rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30) + rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1) + + rec_term0 = rec.interpolate(expr=v) + rec_term3 = rec2.interpolate(expr=v2.forward) + op3 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term3]) + + calls = [i for i in FindNodes(Call).visit(op3) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 5 assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[0])) == 1 assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[1])) == 4 - assert calls[0].arguments[0] is v assert calls[1].arguments[0] is tau assert calls[1].arguments[1] is tau2 @@ -1136,8 +1140,7 @@ def test_issue_2448(self, mode): def test_issue_2448_backward(self, mode): ''' Similar to test_issue_2448, but with backward instead of forward - so that the hoisted halo - + so that the hoisted halo has different starting point ''' shape = (2,) so = 2 @@ -1397,7 +1400,7 @@ def test_avoid_fullmode_if_crossloop_dep(self, mode): assert np.all(f.data[:] == 2.) @pytest.mark.parallel(mode=2) - def test_avoid_haloudate_if_flowdep_along_other_dim(self, mode): + def test_avoid_halopudate_if_flowdep_along_other_dim(self, mode): grid = Grid(shape=(10,)) x = grid.dimensions[0] t = grid.stepping_dim @@ -1535,6 +1538,11 @@ def test_merge_haloupdate_if_diff_locindices_v1(self, mode): calls = FindNodes(Call).visit(op) assert len(calls) == 2 + assert calls[0].arguments[3].args[0] is t.symbolic_min + + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[2])) == 0 op.apply(time_M=1) glb_pos_map = f.grid.distributor.glb_pos_map @@ -2953,7 +2961,7 @@ def test_elastic_structure(self, mode): assert calls[4].arguments[1] is v[1] -class TestTTIwMPI: +class TestTTIOp: @pytest.mark.parallel(mode=1) def test_halo_structure(self, mode):