diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index b45a7e099ca..3036a4e1307 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -206,7 +206,11 @@ def guard(clusters): k = d else: dims = pull_dims(cd.condition) - k = max(dims, default=d, key=lambda i: c.ispace.index(i)) + if (not dims.issubset(set(c.ispace.dimensions)) and + cd.parent in dims): + k = cd.parent + else: + k = max(dims, default=d, key=lambda i: c.ispace.index(i)) # Pull `cd` from any expr condition = guards.setdefault(k, []) diff --git a/tests/test_subdomains.py b/tests/test_subdomains.py index e0e9fc85b46..8b441475710 100644 --- a/tests/test_subdomains.py +++ b/tests/test_subdomains.py @@ -6,7 +6,7 @@ from conftest import opts_tiling, assert_structure from devito import (ConditionalDimension, Constant, Grid, Function, TimeFunction, - Eq, solve, Operator, SubDomain, SubDomainSet) + Eq, solve, Operator, SubDomain, SubDomainSet, Lt) from devito.ir import FindNodes, Expression, Iteration from devito.tools import timed_region @@ -628,3 +628,101 @@ class Dummy(SubDomainSet): assert x.is_Parallel assert y.is_Parallel assert z.is_Parallel + + +class TestSubDomain_w_condition(object): + + def test_condition_w_subdomain(self): + + shape = (10, ) + grid = Grid(shape=shape) + x, = grid.dimensions + + class Middle(SubDomain): + name = 'middle' + + def define(self, dimensions): + return {x: ('middle', 2, 4)} + + mid = Middle() + my_grid = Grid(shape=shape, subdomains=(mid, )) + + f = Function(name='f', grid=my_grid) + + sdf = Function(name='sdf', grid=my_grid) + sdf.data[5:] = 1 + + condition = Lt(sdf[mid.dimensions[0]], 1) + + ci = ConditionalDimension(name='ci', condition=condition, + parent=mid.dimensions[0]) + + op = Operator(Eq(f, f + 10, implicit_dims=ci, + subdomain=my_grid.subdomains['middle'])) + op.apply() + + assert_structure(op, ['i1x'], 'i1x') + + def test_condition_w_subdomain_II(self): + + shape = (10, 10) + grid = Grid(shape=shape) + x, y = grid.dimensions + + class Middle(SubDomain): + name = 'middle' + + def define(self, dimensions): + return {x: x, y: ('middle', 2, 4)} + + mid = Middle() + my_grid = Grid(shape=shape, subdomains=(mid, )) + + sdf = Function(name='sdf', grid=grid) + sdf.data[:, 5:] = 1 + sdf.data[2:6, 3:5] = 1 + + x1, y1 = mid.dimensions + + condition = Lt(sdf[x1, y1], 1) + ci = ConditionalDimension(name='ci', condition=condition, parent=y1) + + f = Function(name='f', grid=my_grid) + op = Operator(Eq(f, f + 10, implicit_dims=ci, + subdomain=my_grid.subdomains['middle'])) + + op.apply() + + assert_structure(op, ['xi1y'], 'xi1y') + + def test_condition_w_subdomain_III(self): + + shape = (10, 10) + grid = Grid(shape=shape) + x, y = grid.dimensions + + class Middle(SubDomain): + name = 'middle' + + def define(self, dimensions): + return {x: ('middle', 2, 4), y: ('middle', 2, 4)} + + mid = Middle() + my_grid = Grid(shape=shape, subdomains=(mid, )) + + sdf = Function(name='sdf', grid=my_grid) + sdf.data[2:4, 5:] = 1 + sdf.data[2:6, 3:5] = 1 + + x1, y1 = mid.dimensions + + condition = Lt(sdf[x1, y1], 1) + ci = ConditionalDimension(name='ci', condition=condition, parent=y1) + + f = Function(name='f', grid=my_grid) + op = Operator(Eq(f, f + 10, implicit_dims=ci, + subdomain=my_grid.subdomains['middle'])) + + op.apply() + + assert_structure(op, ['i1xi1y'], 'i1xi1y')