diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index eac09c0f8d..af8eee97cf 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -238,7 +238,7 @@ def guard(clusters): if cd._factor is not None: k = d else: - dims = pull_dims(cd.condition) + dims = pull_dims(cd.condition, flag=False) k = max(dims, default=d, key=lambda i: c.ispace.index(i)) # Pull `cd` from any expr diff --git a/tests/test_subdomains.py b/tests/test_subdomains.py index 839cb2299d..c5a8ee2faa 100644 --- a/tests/test_subdomains.py +++ b/tests/test_subdomains.py @@ -6,7 +6,7 @@ from conftest import opts_tiling, assert_structure from devito import (ConditionalDimension, Constant, Grid, Function, TimeFunction, - Eq, solve, Operator, SubDomain, SubDomainSet) + Eq, solve, Operator, SubDomain, SubDomainSet, Lt) from devito.ir import FindNodes, Expression, Iteration from devito.tools import timed_region @@ -693,3 +693,101 @@ class Dummy(SubDomainSet): # Switch the thickness symbols between MultiSubDimensions with the rebuild remixed = [d._rebuild(thickness=t) for d, t in zip(sdims, tkns[::-1])] assert [d.thickness for d in remixed] == tkns[::-1] + + +class TestSubDomain_w_condition(object): + + def test_condition_w_subdomain_v0(self): + + shape = (10, ) + grid = Grid(shape=shape) + x, = grid.dimensions + + class Middle(SubDomain): + name = 'middle' + + def define(self, dimensions): + return {x: ('middle', 2, 4)} + + mid = Middle() + my_grid = Grid(shape=shape, subdomains=(mid, )) + + f = Function(name='f', grid=my_grid) + + sdf = Function(name='sdf', grid=my_grid) + sdf.data[5:] = 1 + + condition = Lt(sdf[mid.dimensions[0]], 1) + + ci = ConditionalDimension(name='ci', condition=condition, + parent=mid.dimensions[0]) + + op = Operator(Eq(f, f + 10, implicit_dims=ci, + subdomain=my_grid.subdomains['middle'])) + op.apply() + + assert_structure(op, ['x'], 'x') + + def test_condition_w_subdomain_v1(self): + + shape = (10, 10) + grid = Grid(shape=shape) + x, y = grid.dimensions + + class Middle(SubDomain): + name = 'middle' + + def define(self, dimensions): + return {x: x, y: ('middle', 2, 4)} + + mid = Middle() + my_grid = Grid(shape=shape, subdomains=(mid, )) + + sdf = Function(name='sdf', grid=grid) + sdf.data[:, 5:] = 1 + sdf.data[2:6, 3:5] = 1 + + x1, y1 = mid.dimensions + + condition = Lt(sdf[x1, y1], 1) + ci = ConditionalDimension(name='ci', condition=condition, parent=y1) + + f = Function(name='f', grid=my_grid) + op = Operator(Eq(f, f + 10, implicit_dims=ci, + subdomain=my_grid.subdomains['middle'])) + + op.apply() + + assert_structure(op, ['xy'], 'xy') + + def test_condition_w_subdomain_v2(self): + + shape = (10, 10) + grid = Grid(shape=shape) + x, y = grid.dimensions + + class Middle(SubDomain): + name = 'middle' + + def define(self, dimensions): + return {x: ('middle', 2, 4), y: ('middle', 2, 4)} + + mid = Middle() + my_grid = Grid(shape=shape, subdomains=(mid, )) + + sdf = Function(name='sdf', grid=my_grid) + sdf.data[2:4, 5:] = 1 + sdf.data[2:6, 3:5] = 1 + + x1, y1 = mid.dimensions + + condition = Lt(sdf[x1, y1], 1) + ci = ConditionalDimension(name='ci', condition=condition, parent=y1) + + f = Function(name='f', grid=my_grid) + op = Operator(Eq(f, f + 10, implicit_dims=ci, + subdomain=my_grid.subdomains['middle'])) + + op.apply() + + assert_structure(op, ['xy'], 'xy')