Skip to content

Commit

Permalink
Merge pull request #2208 from devitocodes/revamp-opt-derivs-final
Browse files Browse the repository at this point in the history
compiler: Revamp lowering of IndexDerivatives
  • Loading branch information
FabioLuporini authored Oct 16, 2023
2 parents b6f7308 + c1ebe2f commit fb32972
Show file tree
Hide file tree
Showing 59 changed files with 2,030 additions and 891 deletions.
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def parallel(item):
# OpenMPI requires an explicit flag for oversubscription. We need it as some
# of the MPI tests will spawn lots of processes
if mpi_distro == 'OpenMPI':
call = [mpi_exec, '--oversubscribe', '--timeout', '150'] + args
call = [mpi_exec, '--oversubscribe', '--timeout', '300'] + args
else:
call = [mpi_exec] + args

Expand Down
15 changes: 8 additions & 7 deletions devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,9 @@ class OptOption(object):

class ParTileArg(tuple):

def __new__(cls, items, shm=0, tag=None):
def __new__(cls, items, rule=None, tag=None):
obj = super().__new__(cls, items)
obj.shm = shm
obj.rule = rule
obj.tag = tag
return obj

Expand Down Expand Up @@ -371,14 +371,15 @@ def __new__(cls, items, default=None):

try:
y = items[1]
if is_integer(y):
# E.g., ((32, 4, 8), 1)
# E.g., ((32, 4, 8), 1, 'tag')
if is_integer(y) or isinstance(y, str) or y is None:
# E.g., ((32, 4, 8), 'rule')
# E.g., ((32, 4, 8), 'rule', 'tag')
items = (ParTileArg(*items),)
else:
try:
# E.g., (((32, 4, 8), 1), ((32, 4, 4), 2))
# E.g., (((32, 4, 8), 1, 'tag0'), ((32, 4, 4), 2, 'tag1'))
# E.g., (((32, 4, 8), 'rule'), ((32, 4, 4), 'rule'))
# E.g., (((32, 4, 8), 'rule0', 'tag0'),
# ((32, 4, 4), 'rule1', 'tag1'))
items = tuple(ParTileArg(*i) for i in items)
except TypeError:
# E.g., ((32, 4, 8), (32, 4, 4))
Expand Down
24 changes: 23 additions & 1 deletion devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sympy
from sympy.core.add import _addsort
from sympy.core.mul import _keep_coeff, _mulsort
from sympy.core.core import ordering_of_classes
from sympy.core.decorators import call_highest_priority
from sympy.core.evalf import evalf_table

Expand Down Expand Up @@ -556,6 +557,9 @@ def __repr__(self):

__str__ = __repr__

def _sympystr(self, printer):
return str(self)

def _hashable_content(self):
return super()._hashable_content() + (self.dimensions,)

Expand Down Expand Up @@ -621,7 +625,7 @@ def __eq__(self, other):
__hash__ = sympy.Basic.__hash__

def _hashable_content(self):
return (self.name, self.dimension, hash(tuple(self.weights)))
return (self.name, self.dimension, str(self.weights))

@property
def dimension(self):
Expand Down Expand Up @@ -665,6 +669,20 @@ def __new__(cls, expr, mapper, **kwargs):
def _hashable_content(self):
return super()._hashable_content() + (self.mapper,)

def compare(self, other):
if self is other:
return 0
n1 = self.__class__
n2 = other.__class__
if n1.__name__ == n2.__name__:
return self.base.compare(other.base)
else:
return super().compare(other)

@cached_property
def base(self):
return self.expr.func(*[a for a in self.expr.args if a is not self.weights])

@property
def weights(self):
return self._weights
Expand Down Expand Up @@ -693,6 +711,10 @@ def _evaluate(self, **kwargs):
return expr


ordering_of_classes.insert(ordering_of_classes.index('Derivative') + 1,
'IndexDerivative')


class EvalDerivative(DifferentiableOp, sympy.Add):

is_commutative = True
Expand Down
16 changes: 11 additions & 5 deletions devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,11 @@ def generic_derivative(expr, dim, fd_order, deriv_order, matvec=direct, x0=None,
matvec, x0, symbolic, expand)


def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic, expand):
def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic,
expand):
# The stencil indices
indices, x0 = generate_indices(expr, dim, fd_order, side=side, matvec=matvec, x0=x0)
indices, x0 = generate_indices(expr, dim, fd_order, side=side, matvec=matvec,
x0=x0)

# Finite difference weights from Taylor approximation given these positions
if symbolic:
Expand All @@ -218,15 +220,19 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic
weights = numeric_weights(deriv_order, indices, x0)

# Enforce fixed precision FD coefficients to avoid variations in results
weights = [sympify(w).evalf(_PRECISION) for w in weights]
weights = [sympify(w).evalf(_PRECISION) for w in weights][::matvec.val]

# Transpose the FD, if necessary
if matvec:
indices = indices.scale(matvec.val)
if matvec == transpose:
indices = indices.transpose()

# Shift index due to staggering, if any
indices = indices.shift(-(expr.indices_ref[dim] - dim))

# The user may wish to restrict expansion to selected derivatives
if callable(expand):
expand = expand(dim)

if not expand and indices.expr is not None:
weights = Weights(name='w', dimensions=indices.free_dim, initvalue=weights)

Expand Down
16 changes: 11 additions & 5 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,27 +175,33 @@ def __repr__(self):
def spacing(self):
return self.dim.spacing

def scale(self, v):
def transpose(self):
"""
Construct a new IndexSet with all indices scaled by `v`.
Transpose the IndexSet.
"""
mapper = {self.spacing: v*self.spacing}
mapper = {self.spacing: -self.spacing}

indices = []
for i in self:
for i in reversed(self):
try:
iloc = i.xreplace(mapper)
except AttributeError:
# Pure number -> sympify
iloc = sympify(i).xreplace(mapper)
indices.append(iloc)

try:
free_dim = self.free_dim.transpose()
mapper.update({self.free_dim: -free_dim})
except AttributeError:
free_dim = self.free_dim

try:
expr = self.expr.xreplace(mapper)
except AttributeError:
expr = None

return IndexSet(self.dim, indices, expr=expr, fd=self.free_dim)
return IndexSet(self.dim, indices, expr=expr, fd=free_dim)

def shift(self, v):
"""
Expand Down
49 changes: 40 additions & 9 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
import sympy

from devito.exceptions import InvalidOperator
from devito.ir.support import (Any, Backward, Forward, IterationSpace,
from devito.ir.support import (Any, Backward, Forward, IterationSpace, erange,
pull_dims)
from devito.ir.clusters.analysis import analyze
from devito.ir.clusters.cluster import Cluster, ClusterGroup
from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
from devito.symbolics import retrieve_indexed, uxreplace, xreplace_indices
from devito.tools import (DefaultOrderedDict, Stamp, as_mapper, flatten,
is_integer, timed_pass)
is_integer, timed_pass, toposort)
from devito.types import Array, Eq, Symbol
from devito.types.dimension import BOTTOM, ModuloDimension

Expand All @@ -29,6 +29,7 @@ def clusterize(exprs, **kwargs):
clusters = [Cluster(e, e.ispace) for e in exprs]

# Setup the IterationSpaces based on data dependence analysis
clusters = impose_total_ordering(clusters)
clusters = Schedule().process(clusters)

# Handle SteppingDimensions
Expand All @@ -49,6 +50,29 @@ def clusterize(exprs, **kwargs):
return ClusterGroup(clusters)


def impose_total_ordering(clusters):
"""
Create a new sequence of Clusters whose IterationSpaces are totally ordered
according to a global set of relations.
"""
global_relations = set().union(*[c.ispace.relations for c in clusters])
ordering = toposort(global_relations)

processed = []
for c in clusters:
key = lambda d: ordering.index(d)
try:
relations = {tuple(sorted(c.ispace.itdims, key=key))}
except ValueError:
# See issue #2204
relations = c.ispace.relations
ispace = c.ispace.reorder(relations=relations, mode='total')

processed.append(c.rebuild(ispace=ispace))

return processed


class Schedule(QueueStateful):

"""
Expand Down Expand Up @@ -121,10 +145,12 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
require_break = scope.d_flow.cause & maybe_break
if require_break:
backlog = [clusters[-1]] + backlog
# Try with increasingly smaller ClusterGroups until the ambiguity is gone
# Try with increasingly smaller ClusterGroups until the
# ambiguity is gone
return self.callback(clusters[:-1], prefix, backlog, require_break)

# Schedule Clusters over different IterationSpaces if this increases parallelism
# Schedule Clusters over different IterationSpaces if this increases
# parallelism
for i in range(1, len(clusters)):
if self._break_for_parallelism(scope, candidates, i):
return self.callback(clusters[:i], prefix, clusters[i:] + backlog,
Expand All @@ -146,8 +172,8 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
if not backlog:
return processed

# Handle the backlog -- the Clusters characterized by flow- and anti-dependences
# along one or more Dimensions
# Handle the backlog -- the Clusters characterized by flow- and
# anti-dependences along one or more Dimensions
idir = {d: Any for d in known_break}
stamp = Stamp()
for i, c in enumerate(list(backlog)):
Expand Down Expand Up @@ -278,7 +304,11 @@ def callback(self, clusters, prefix):
size = i.function.shape_allocated[d]
assert is_integer(size)

mapper[size][si].add(iaf)
# Resolve StencilDimensions in case of unexpanded expressions
# E.g. `i0 + t` -> `(t - 1, t, t + 1)`
iafs = erange(iaf)

mapper[size][si].update(iafs)

# Construct the ModuloDimensions
mds = []
Expand All @@ -288,7 +318,8 @@ def callback(self, clusters, prefix):
# SymPy's index ordering (t, t-1, t+1) afer modulo replacement so
# that associativity errors are consistent. This corresponds to
# sorting offsets {-1, 0, 1} as {0, -1, 1} assigning -inf to 0
siafs = sorted(iafs, key=lambda i: -np.inf if i - si == 0 else (i - si))
key = lambda i: -np.inf if i - si == 0 else (i - si)
siafs = sorted(iafs, key=key)

for iaf in siafs:
name = '%s%d' % (si.name, len(mds))
Expand Down Expand Up @@ -452,7 +483,7 @@ def normalize_reductions_dense(cluster, sregistry, options):
"""
opt_mapify_reduce = options['mapify-reduce']

dims = [d for d in cluster.properties.dimensions
dims = [d for d in cluster.ispace.itdims
if cluster.properties.is_parallel_atomic(d)]

if not dims:
Expand Down
Loading

0 comments on commit fb32972

Please sign in to comment.