Skip to content

Commit

Permalink
compiler: Fix placement of ConditionalDimension depending on cd.paren…
Browse files Browse the repository at this point in the history
…t and subdomain case
  • Loading branch information
georgebisbas committed Mar 17, 2023
1 parent 9061a67 commit 852545a
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 2 deletions.
8 changes: 7 additions & 1 deletion devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,13 @@ def guard(clusters):
k = d
else:
dims = pull_dims(cd.condition)
k = max(dims, default=d, key=lambda i: c.ispace.index(i))
# If `cd` uses more dimensions than the ispace,
# stay under parent
if (not dims.issubset(set(c.ispace.dimensions)) and
cd.parent in dims):
k = cd.parent
else:
k = max(dims, default=d, key=lambda i: c.ispace.index(i))

# Pull `cd` from any expr
condition = guards.setdefault(k, [])
Expand Down
100 changes: 99 additions & 1 deletion tests/test_subdomains.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from conftest import opts_tiling, assert_structure
from devito import (ConditionalDimension, Constant, Grid, Function, TimeFunction,
Eq, solve, Operator, SubDomain, SubDomainSet)
Eq, solve, Operator, SubDomain, SubDomainSet, Lt)
from devito.ir import FindNodes, Expression, Iteration
from devito.tools import timed_region

Expand Down Expand Up @@ -628,3 +628,101 @@ class Dummy(SubDomainSet):
assert x.is_Parallel
assert y.is_Parallel
assert z.is_Parallel


class TestSubDomain_w_condition(object):

def test_condition_w_subdomain(self):

shape = (10, )
grid = Grid(shape=shape)
x, = grid.dimensions

class Middle(SubDomain):
name = 'middle'

def define(self, dimensions):
return {x: ('middle', 2, 4)}

mid = Middle()
my_grid = Grid(shape=shape, subdomains=(mid, ))

f = Function(name='f', grid=my_grid)

sdf = Function(name='sdf', grid=my_grid)
sdf.data[5:] = 1

condition = Lt(sdf[mid.dimensions[0]], 1)

ci = ConditionalDimension(name='ci', condition=condition,
parent=mid.dimensions[0])

op = Operator(Eq(f, f + 10, implicit_dims=ci,
subdomain=my_grid.subdomains['middle']))
op.apply()

assert_structure(op, ['i1x'], 'i1x')

def test_condition_w_subdomain_II(self):

shape = (10, 10)
grid = Grid(shape=shape)
x, y = grid.dimensions

class Middle(SubDomain):
name = 'middle'

def define(self, dimensions):
return {x: x, y: ('middle', 2, 4)}

mid = Middle()
my_grid = Grid(shape=shape, subdomains=(mid, ))

sdf = Function(name='sdf', grid=grid)
sdf.data[:, 5:] = 1
sdf.data[2:6, 3:5] = 1

x1, y1 = mid.dimensions

condition = Lt(sdf[x1, y1], 1)
ci = ConditionalDimension(name='ci', condition=condition, parent=y1)

f = Function(name='f', grid=my_grid)
op = Operator(Eq(f, f + 10, implicit_dims=ci,
subdomain=my_grid.subdomains['middle']))

op.apply()

assert_structure(op, ['xi1y'], 'xi1y')

def test_condition_w_subdomain_III(self):

shape = (10, 10)
grid = Grid(shape=shape)
x, y = grid.dimensions

class Middle(SubDomain):
name = 'middle'

def define(self, dimensions):
return {x: ('middle', 2, 4), y: ('middle', 2, 4)}

mid = Middle()
my_grid = Grid(shape=shape, subdomains=(mid, ))

sdf = Function(name='sdf', grid=my_grid)
sdf.data[2:4, 5:] = 1
sdf.data[2:6, 3:5] = 1

x1, y1 = mid.dimensions

condition = Lt(sdf[x1, y1], 1)
ci = ConditionalDimension(name='ci', condition=condition, parent=y1)

f = Function(name='f', grid=my_grid)
op = Operator(Eq(f, f + 10, implicit_dims=ci,
subdomain=my_grid.subdomains['middle']))

op.apply()

assert_structure(op, ['i1xi1y'], 'i1xi1y')

0 comments on commit 852545a

Please sign in to comment.