Skip to content

Commit

Permalink
api: make interp radius dimension always conditional
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Aug 31, 2023
1 parent 2752814 commit ea78099
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 36 deletions.
2 changes: 1 addition & 1 deletion devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __new__(cls, dim, indices=None, expr=None, fd=None):

if fd is None:
try:
v = {d for d in expr.free_symbols if isinstance(d, StencilDimension)}
v = expr.atoms(StencilDimension)
assert len(v) == 1
fd = v.pop()
except AttributeError:
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __hash__(self):
def index_mode(self):
retval = []
for i, fi in zip(self, self.findices):
dims = {j for j in i.free_symbols if isinstance(j, Dimension)}
dims = i.atoms(Dimension)
if len(dims) == 0 and q_constant(i):
retval.append(AFFINE)
continue
Expand Down
6 changes: 3 additions & 3 deletions devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def detect_accesses(exprs):
mapper[f][a].update([0])

elif a.is_Add:
dims = {i for i in a.free_symbols if isinstance(i, Dimension)}
dims = a.atoms(Dimension)

if not dims:
continue
Expand Down Expand Up @@ -181,7 +181,7 @@ def detect_accesses(exprs):
# Compute M[None]
other_dims = set()
for e in as_tuple(exprs):
other_dims.update(i for i in e.free_symbols if isinstance(i, Dimension))
other_dims.update(e.atoms(Dimension))
other_dims.update(e.implicit_dims)
mapper[None] = Stencil([(i, 0) for i in other_dims])

Expand Down Expand Up @@ -262,7 +262,7 @@ def pull_dims(exprs, flag=True):
"""
dims = set()
for e in as_tuple(exprs):
dims.update({i for i in e.free_symbols if i.is_Dimension})
dims.update(e.atoms(Dimension))
if flag:
return set().union(*[d._defines for d in dims])
else:
Expand Down
23 changes: 12 additions & 11 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,19 @@ def r(self):
@cached_property
def _rdim(self):
parent = self.sfunction.dimensions[-1]
pos = self.sfunction._position_map.values()
dims = [CustomDimension("r%s%s" % (self.sfunction.name, d.name),
-self.r+1, self.r, 2*self.r, parent)
for d in self._gdims]
rdims = []
for (d, rd, p) in zip(self._gdims, dims, pos):
# Add conditional to avoid OOB
lb = sympy.And(rd + p >= d.symbolic_min - self.r, evaluate=False)
ub = sympy.And(rd + p <= d.symbolic_max + self.r, evaluate=False)
cond = sympy.And(lb, ub, evaluate=False)
rdims.append(ConditionalDimension(rd.name, rd, condition=cond, indirect=True))

return DimensionTuple(*dims, getters=self._gdims)
return DimensionTuple(*rdims, getters=self._gdims)

def _augment_implicit_dims(self, implicit_dims):
if self.sfunction._sparse_position == -1:
Expand All @@ -174,23 +182,16 @@ def _interp_idx(self, variables, implicit_dims=None):
"""
Generate interpolation indices for the DiscreteFunctions in ``variables``.
"""
mapper = {}
pos = self.sfunction._position_map.values()
# Temporaries for the position
temps = self._positions(implicit_dims)

# Coefficient symbol expression
temps.extend(self._coeff_temps(implicit_dims))
for ((di, d), rd, p) in zip(enumerate(self._gdims), self._rdim, pos):
# Add conditional to avoid OOB
lb = sympy.And(rd + p >= d.symbolic_min - self.r, evaluate=False)
ub = sympy.And(rd + p <= d.symbolic_max + self.r, evaluate=False)
cond = sympy.And(lb, ub, evaluate=False)
mapper[d] = ConditionalDimension(rd.name, rd, condition=cond, indirect=True)

# Substitution mapper for variables
idx_subs = {v: v.subs({k: c - v.origin.get(k, 0) + p
for ((k, c), p) in zip(mapper.items(), pos)})
idx_subs = {v: v.subs({d: r - v.origin.get(d, 0) + p
for (r, d, p) in zip(self._rdim, self._gdims, pos)})
for v in variables}

return idx_subs, temps
Expand Down Expand Up @@ -287,7 +288,7 @@ def _inject(self, field, expr, implicit_dims=None):
injection expression, but that should be honored when constructing
the operator.
"""
implicit_dims = self._augment_implicit_dims(implicit_dims) + self._rdim
implicit_dims = self._augment_implicit_dims(implicit_dims)

# Make iterable to support inject((u, v), expr=expr)
# or inject((u, v), expr=(expr1, expr2))
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
split, timed_pass)
from devito.types import (Array, TempFunction, Eq, Symbol, Temp, ModuloDimension,
CustomDimension, IncrDimension, StencilDimension, Indexed,
Hyperplane)
Hyperplane, Dimension)
from devito.types.grid import MultiSubDimension

__all__ = ['cire']
Expand Down Expand Up @@ -1389,7 +1389,7 @@ def nredundants(ispace, expr):
a non-redundant iteration space (e.g., a BlockDimension).
"""
iterated = {i.dim for i in ispace}
used = {i for i in expr.free_symbols if i.is_Dimension}
used = expr.atoms(Dimension)

# "Short" dimensions won't count
key0 = lambda d: d.is_Sub and d.local
Expand Down
2 changes: 1 addition & 1 deletion tests/test_buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_over_injection():

# Check generated code
assert len(retrieve_iteration_tree(op1)) == \
7 + int(configuration['language'] != 'C')
8 + int(configuration['language'] != 'C')
buffers = [i for i in FindSymbols().visit(op1) if i.is_Array]
assert len(buffers) == 1

Expand Down
14 changes: 3 additions & 11 deletions tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ def test_cache_blocking_structure_optrelax():

op = Operator(eqns, opt=('advanced', {'blockrelax': True}))

bns, _ = assert_blocking(op, {'x0_blk0', 'p_src0_blk0'})
bns, _ = assert_blocking(op, {'x0_blk0', 'p_src0_blk0', 'p_src1_blk0'})

iters = FindNodes(Iteration).visit(bns['p_src0_blk0'])
iters = FindNodes(Iteration).visit(bns['p_src1_blk0'])
assert len(iters) == 5
assert iters[0].dim.is_Block
assert iters[1].dim.is_Block
Expand Down Expand Up @@ -286,7 +286,7 @@ def test_cache_blocking_structure_optrelax_prec_inject():
'openmp': True,
'par-collapse-ncores': 1}))

assert_structure(op, ['t', 't,p_s0_blk0,p_s,rsx,rsy'],
assert_structure(op, ['t', 't,p_s0_blk0,p_s', 't,p_s0_blk0,p_s,rsx,rsy'],
't,p_s0_blk0,p_s,rsx,rsy')


Expand Down Expand Up @@ -952,14 +952,6 @@ def test_parallel_prec_inject(self):
assert not iterations[0].pragmas
assert 'omp for' in iterations[1].pragmas[0].value

op0 = Operator(eqns, opt=('advanced', {'openmp': True,
'par-collapse-ncores': 1,
'par-collapse-work': 1}))
iterations = FindNodes(Iteration).visit(op0)

assert not iterations[0].pragmas
assert 'omp for collapse(2)' in iterations[1].pragmas[0].value


class TestNestedParallelism(object):

Expand Down
8 changes: 4 additions & 4 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def test_scheduling_after_rewrite():
trees = retrieve_iteration_tree(op)

# Check loop nest structure
assert all(i.dim is j for i, j in zip(trees[0], grid.dimensions)) # time invariant
assert trees[1].root.dim is grid.time_dim
assert all(trees[1].root.dim is tree.root.dim for tree in trees[1:])
assert all(i.dim is j for i, j in zip(trees[1], grid.dimensions)) # time invariant
assert trees[2].root.dim is grid.time_dim
assert all(trees[2].root.dim is tree.root.dim for tree in trees[2:])


@pytest.mark.parametrize('exprs,expected,min_cost', [
Expand Down Expand Up @@ -1665,7 +1665,7 @@ def test_drop_redundants_after_fusion(self, rotate):
op = Operator(eqns, opt=('advanced', {'cire-rotate': rotate}))

arrays = [i for i in FindSymbols().visit(op) if i.is_Array]
assert len(arrays) == 2
assert len(arrays) == 4
assert all(i._mem_heap and not i._mem_external for i in arrays)

def test_full_shape_big_temporaries(self):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1803,7 +1803,8 @@ def test_scheduling_sparse_functions(self):
# `trees` than 6
op = Operator([eqn1] + eqn2 + [eqn3] + eqn4, opt=('noop', {'openmp': False}))
trees = retrieve_iteration_tree(op)
assert len(trees) == 5

assert len(trees) == 6
# Time loop not shared due to the WAR
assert trees[0][0].dim is time and trees[0][0] is trees[1][0] # this IS shared
assert trees[1][0] is not trees[3][0]
Expand All @@ -1813,7 +1814,7 @@ def test_scheduling_sparse_functions(self):
eqn2 = sf1.inject(u1.forward, expr=sf1)
op = Operator([eqn1] + eqn2 + [eqn3] + eqn4, opt=('noop', {'openmp': False}))
trees = retrieve_iteration_tree(op)
assert len(trees) == 5
assert len(trees) == 6
assert all(trees[0][0] is i[0] for i in trees)

def test_scheduling_with_free_dims(self):
Expand Down

0 comments on commit ea78099

Please sign in to comment.