Skip to content

Commit

Permalink
Merge pull request #2453 from devitocodes/opt-fd-interp-tmp
Browse files Browse the repository at this point in the history
compiler: Minor tweaks for elastic code gen
  • Loading branch information
mloubout authored Sep 27, 2024
2 parents 25d87fc + 7aec615 commit 2ae6822
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 38 deletions.
4 changes: 4 additions & 0 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,10 @@ def __init_finalize__(self, **kwargs):
if not configuration['safe-math']:
self.cflags.append('--use_fast_math')

# Optionally print out per-kernel shared memory and register usage
if configuration['profiling'] == 'advanced2':
self.cflags.append('--ptxas-options=-v')

self.src_ext = 'cu'

# NOTE: not sure where we should place this. It definitely needs
Expand Down
2 changes: 1 addition & 1 deletion devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,6 @@ def _(expr, x0, **kwargs):
if x0_expr:
dims = tuple((d, 0) for d in x0_expr)
fd_o = tuple([2]*len(dims))
return Derivative(expr, *dims, fd_order=fd_o, x0=x0_expr)._evaluate(**kwargs)
return Derivative(expr, *dims, fd_order=fd_o, x0=x0_expr)
else:
return expr
2 changes: 1 addition & 1 deletion devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def make_stencil_dimension(expr, _min, _max):
Create a StencilDimension for `expr` with unique name.
"""
n = len(expr.find(StencilDimension))
return StencilDimension(name='i%d' % n, _min=_min, _max=_max)
return StencilDimension('i%d' % n, _min, _max)


@cacheit
Expand Down
26 changes: 15 additions & 11 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from devito.mpi.reduction_scheme import DistReduce
from devito.symbolics import estimate_cost
from devito.tools import as_tuple, flatten, infer_dtype
from devito.types import WeakFence, CriticalRegion
from devito.types import Fence, WeakFence, CriticalRegion

__all__ = ["Cluster", "ClusterGroup"]

Expand Down Expand Up @@ -239,42 +239,46 @@ def is_sparse(self):
"""
return any(a.is_irregular for a in self.scope.accesses)

@property
@cached_property
def is_wild(self):
"""
True if encoding a non-mathematical operation, False otherwise.
"""
return self.is_halo_touch or self.is_dist_reduce or self.is_fence
return (self.is_halo_touch or
self.is_dist_reduce or
self.is_weak_fence or
self.is_critical_region)

@property
@cached_property
def is_halo_touch(self):
return self.exprs and all(isinstance(e.rhs, HaloTouch) for e in self.exprs)

@property
@cached_property
def is_dist_reduce(self):
return self.exprs and all(isinstance(e.rhs, DistReduce) for e in self.exprs)

@property
@cached_property
def is_fence(self):
return self.is_weak_fence or self.is_critical_region
return (self.exprs and all(isinstance(e.rhs, Fence) for e in self.exprs) or
self.is_critical_region)

@property
@cached_property
def is_weak_fence(self):
return self.exprs and all(isinstance(e.rhs, WeakFence) for e in self.exprs)

@property
@cached_property
def is_critical_region(self):
return self.exprs and all(isinstance(e.rhs, CriticalRegion) for e in self.exprs)

@property
@cached_property
def is_async(self):
"""
True if an asynchronous Cluster, False otherwise.
"""
return any(isinstance(s, (WithLock, PrefetchUpdate))
for s in flatten(self.syncs.values()))

@property
@cached_property
def is_wait(self):
"""
True if a Cluster waiting on a lock (that is a special synchronization
Expand Down
12 changes: 8 additions & 4 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ def is_cross(source, sink):
return t0 < v <= t1 or t1 < v <= t0

for cg1 in cgroups[n+1:]:
n1 = cgroups.index(cg1)

# A Scope to compute all cross-ClusterGroup anti-dependences
scope = Scope(exprs=cg0.exprs + cg1.exprs, rules=is_cross)

Expand All @@ -355,14 +357,16 @@ def is_cross(source, sink):
break

# Any anti- and iaw-dependences impose that `cg1` follows `cg0`
# and forbid any sort of fusion
elif any(scope.d_anti_gen()) or\
any(i.is_iaw for i in scope.d_output_gen()):
# and forbid any sort of fusion. Fences have the same effect
elif (any(scope.d_anti_gen()) or
any(i.is_iaw for i in scope.d_output_gen()) or
any(c.is_fence for c in flatten(cgroups[n:n1+1]))):
dag.add_edge(cg0, cg1)

# Any flow-dependences along an inner Dimension (i.e., a Dimension
# that doesn't appear in `prefix`) impose that `cg1` follows `cg0`
elif any(not (i.cause and i.cause & prefix) for i in scope.d_flow_gen()):
elif any(not (i.cause and i.cause & prefix)
for i in scope.d_flow_gen()):
dag.add_edge(cg0, cg1)

# Clearly, output dependences must be honored
Expand Down
24 changes: 20 additions & 4 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@
from sympy.core.add import _addsort
from sympy.core.mul import _mulsort

from devito.finite_differences.differentiable import EvalDerivative
from devito.finite_differences.differentiable import (
EvalDerivative, IndexDerivative
)
from devito.symbolics.extended_sympy import DefFunction, rfunc
from devito.symbolics.queries import q_leaf
from devito.symbolics.search import retrieve_indexed, retrieve_functions
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
from devito.types.basic import Basic
from devito.types.basic import Basic, Indexed
from devito.types.array import ComponentAccess
from devito.types.equation import Eq
from devito.types.relational import Le, Lt, Gt, Ge

__all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args',
'normalize_args', 'uxreplace', 'Uxmapper', 'reuse_if_untouched',
'evalrel', 'flatten_args']
'normalize_args', 'uxreplace', 'Uxmapper', 'subs_if_composite',
'reuse_if_untouched', 'evalrel', 'flatten_args']


def uxreplace(expr, rule):
Expand Down Expand Up @@ -246,6 +248,20 @@ def add(self, expr, make, terms=None):
self[base] = self.extracted[base] = make()


def subs_if_composite(expr, subs):
"""
Call `expr.subs(subs)` if `subs` contain composite expressions, that is
expressions that can be part of larger expressions of the same type (e.g.,
`a*b` could be part of `a*b*c`, while `a[1]` cannot be part of a "larger
Indexed"). Instead, if `subs` consists of just "primitive" expressions, then
resort to the much faster `uxreplace`.
"""
if all(isinstance(i, (Indexed, IndexDerivative)) for i in subs):
return uxreplace(expr, subs)
else:
return expr.subs(subs)


def xreplace_indices(exprs, mapper, key=None):
"""
Replace array indices in expressions.
Expand Down
6 changes: 6 additions & 0 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import inspect
from collections import namedtuple
from ctypes import POINTER, _Pointer, c_char_p, c_char
from functools import reduce, cached_property
Expand Down Expand Up @@ -490,6 +491,11 @@ def _cache_key(cls, *args, **kwargs):
# From the kwargs
key.update(kwargs)

# Any missing __rkwargs__ along with their default values
params = inspect.signature(cls.__init_finalize__).parameters
missing = [i for i in cls.__rkwargs__ if i in set(params).difference(key)]
key.update({i: params[i].default for i in missing})

return frozendict(key)

def __new__(cls, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,9 +1534,9 @@ class StencilDimension(BasicDimension):
__rargs__ = BasicDimension.__rargs__ + ('_min', '_max')
__rkwargs__ = BasicDimension.__rkwargs__ + ('step',)

def __init_finalize__(self, name, _min, _max, spacing=None, step=1,
def __init_finalize__(self, name, _min, _max, spacing=1, step=1,
**kwargs):
self._spacing = sympy.sympify(spacing) or sympy.S.One
self._spacing = sympy.sympify(spacing)

if not is_integer(_min):
raise ValueError("Expected integer `min` (got %s)" % _min)
Expand Down
33 changes: 18 additions & 15 deletions tests/test_differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,31 +61,34 @@ def test_interp():
a = Function(name="a", grid=grid, staggered=NODE)
sa = Function(name="as", grid=grid, staggered=x)

sp_diff = lambda a, b: sympy.simplify(a - b) == 0
def sp_diff(a, b):
a = getattr(a, 'evaluate', a)
b = getattr(b, 'evaluate', b)
return sympy.simplify(a - b) == 0

# Base case, no interp
assert interp_for_fd(a, {}, expand=True) == a
assert interp_for_fd(a, {x: x}, expand=True) == a
assert interp_for_fd(sa, {}, expand=True) == sa
assert interp_for_fd(sa, {x: x + x.spacing/2}, expand=True) == sa
assert interp_for_fd(a, {}) == a
assert interp_for_fd(a, {x: x}) == a
assert interp_for_fd(sa, {}) == sa
assert interp_for_fd(sa, {x: x + x.spacing/2}) == sa

# Base case, interp
assert sp_diff(interp_for_fd(a, {x: x + x.spacing/2}, expand=True),
assert sp_diff(interp_for_fd(a, {x: x + x.spacing/2}),
.5*a + .5*a.shift(x, x.spacing))
assert sp_diff(interp_for_fd(sa, {x: x}, expand=True),
assert sp_diff(interp_for_fd(sa, {x: x}),
.5*sa + .5*sa.shift(x, -x.spacing))

# Mul case, split interp
assert sp_diff(interp_for_fd(a*sa, {x: x + x.spacing/2}, expand=True),
sa * interp_for_fd(a, {x: x + x.spacing/2}, expand=True))
assert sp_diff(interp_for_fd(a*sa, {x: x}, expand=True),
a * interp_for_fd(sa, {x: x}, expand=True))
assert sp_diff(interp_for_fd(a*sa, {x: x + x.spacing/2}),
sa * interp_for_fd(a, {x: x + x.spacing/2}))
assert sp_diff(interp_for_fd(a*sa, {x: x}),
a * interp_for_fd(sa, {x: x}))

# Add case, split interp
assert sp_diff(interp_for_fd(a + sa, {x: x + x.spacing/2}, expand=True),
sa + interp_for_fd(a, {x: x + x.spacing/2}, expand=True))
assert sp_diff(interp_for_fd(a + sa, {x: x}, expand=True),
a + interp_for_fd(sa, {x: x}, expand=True))
assert sp_diff(interp_for_fd(a + sa, {x: x + x.spacing/2}),
sa + interp_for_fd(a, {x: x + x.spacing/2}))
assert sp_diff(interp_for_fd(a + sa, {x: x}),
a + interp_for_fd(sa, {x: x}))


@pytest.mark.parametrize('ndim', [1, 2, 3])
Expand Down
25 changes: 25 additions & 0 deletions tests/test_rebuild.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import pytest

from devito import Dimension, Function
from devito.types import StencilDimension
from devito.data.allocators import DataReference


Expand Down Expand Up @@ -40,3 +42,26 @@ def test_w_new_dims(self):
assert f3.function is f3
assert f3.dimensions == dims0
assert np.all(f3.data[:] == 1)


class TestDimension:

def test_stencil_dimension(self):
sd0 = StencilDimension('i', 0, 1)
sd1 = StencilDimension('i', 0, 1)

# StencilDimensions are cached by devito so they are guaranteed to be
# unique for a given set of args/kwargs
assert sd0 is sd1

# Same applies to reconstruction
sd2 = sd0._rebuild()
assert sd0 is sd2

@pytest.mark.xfail(reason="Borked caching when supplying a kwarg for an arg")
def test_stencil_dimension_borked(self):
sd0 = StencilDimension('i', 0, _max=1)
sd1 = sd0._rebuild()

# TODO: Look into Symbol._cache_key and the way the key is generated
assert sd0 is sd1

0 comments on commit 2ae6822

Please sign in to comment.