Skip to content

Commit

Permalink
Merge pull request #2198 from devitocodes/drop-sf-parent
Browse files Browse the repository at this point in the history
api: Cleanup and improve SubFunction
  • Loading branch information
mloubout authored Sep 8, 2023
2 parents a7c3446 + 68c29d4 commit 0d4c099
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 122 deletions.
17 changes: 10 additions & 7 deletions devito/passes/clusters/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,16 @@ def run(expr):
terms.append(i)

# Collect common funcs
w_funcs = Add(*w_funcs, evaluate=False)
w_funcs = collect(w_funcs, funcs, evaluate=False)
try:
terms.extend([Mul(k, collect_const(v), evaluate=False)
for k, v in w_funcs.items()])
except AttributeError:
assert w_funcs == 0
if len(w_funcs) > 1:
w_funcs = Add(*w_funcs, evaluate=False)
w_funcs = collect(w_funcs, funcs, evaluate=False)
try:
terms.extend([Mul(k, collect_const(v), evaluate=False)
for k, v in w_funcs.items()])
except AttributeError:
assert w_funcs == 0
else:
terms.extend(w_funcs)

# Collect common pows
w_pows = Add(*w_pows, evaluate=False)
Expand Down
9 changes: 6 additions & 3 deletions devito/symbolics/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# * Number
# * Symbol
# * Indexed
extra_leaves = (FieldFromPointer, FieldFromComposite, IndexedBase, AbstractObject)
extra_leaves = (FieldFromPointer, FieldFromComposite, IndexedBase, AbstractObject,
IndexedPointer)


def q_symbol(expr):
Expand All @@ -31,7 +32,9 @@ def q_symbol(expr):


def q_leaf(expr):
return expr.is_Atom or expr.is_Indexed or isinstance(expr, extra_leaves)
return (expr.is_Atom or
expr.is_Indexed or
isinstance(expr, extra_leaves))


def q_indexed(expr):
Expand All @@ -51,7 +54,7 @@ def q_derivative(expr):
def q_terminal(expr):
return (expr.is_Symbol or
expr.is_Indexed or
isinstance(expr, extra_leaves + (IndexedPointer,)))
isinstance(expr, extra_leaves))


def q_routine(expr):
Expand Down
17 changes: 4 additions & 13 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,7 @@ def __indices_setup__(cls, *args, **kwargs):
dimensions = grid.dimensions

if args:
assert len(args) == len(dimensions)
return tuple(dimensions), tuple(args)

# Staggered indices
Expand Down Expand Up @@ -1449,16 +1450,10 @@ class SubFunction(Function):
"""
A Function bound to a "parent" DiscreteFunction.
A SubFunction hands control of argument binding and halo exchange to its
parent DiscreteFunction.
A SubFunction hands control of argument binding and halo exchange to the
DiscreteFunction it's bound to.
"""

__rkwargs__ = Function.__rkwargs__ + ('parent',)

def __init_finalize__(self, *args, **kwargs):
super(SubFunction, self).__init_finalize__(*args, **kwargs)
self._parent = kwargs['parent']

def __padding_setup__(self, **kwargs):
# SubFunctions aren't expected to be used in time-consuming loops
return tuple((0, 0) for i in range(self.ndim))
Expand All @@ -1470,12 +1465,8 @@ def _arg_values(self, **kwargs):
if self.name in kwargs:
raise RuntimeError("`%s` is a SubFunction, so it can't be assigned "
"a value dynamically" % self.name)
else:
return self._parent._arg_defaults(alias=self._parent).reduce_all()

@property
def parent(self):
return self._parent
return self._arg_defaults(alias=self)

@property
def origin(self):
Expand Down
142 changes: 66 additions & 76 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,9 @@ def __indices_setup__(cls, *args, **kwargs):
dimensions = (Dimension(name='p_%s' % kwargs["name"]),)

if args:
indices = args
return tuple(dimensions), tuple(args)
else:
indices = dimensions

return dimensions, indices
return dimensions, dimensions

@classmethod
def __shape_setup__(cls, **kwargs):
Expand All @@ -80,16 +78,6 @@ def __shape_setup__(cls, **kwargs):
shape = (glb_npoint[grid.distributor.myrank],)
return shape

def func(self, *args, **kwargs):
# Rebuild subfunctions first to avoid new data creation as we have to use `_data`
# as a reconstruction kwargs to avoid the circular dependency
# with the parent in SubFunction
# This is also necessary to avoid shape issue in the SubFunction with mpi
for s in self._sub_functions:
if getattr(self, s) is not None:
kwargs.update({s: getattr(self, s).func(*args, **kwargs)})
return super().func(*args, **kwargs)

def __fd_setup__(self):
"""
Dynamically add derivative short-cuts.
Expand All @@ -108,24 +96,39 @@ def __distributor_setup__(self, **kwargs):
)

def __subfunc_setup__(self, key, suffix, dtype=None):
# Shape and dimensions from args
name = '%s_%s' % (self.name, suffix)

if key is not None and not isinstance(key, SubFunction):
key = np.array(key)

if key is not None:
dimensions = (self._sparse_dim, Dimension(name='d'))
if key.ndim > 2:
dimensions = (self._sparse_dim, Dimension(name='d'),
*mkdims("i", n=key.ndim-2))
else:
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim, *key.shape[2:])
else:
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim)

# Check if already a SubFunction
if isinstance(key, SubFunction):
return key
# Need to rebuild so the dimensions match the parent SparseFunction
indices = (self.indices[self._sparse_position], *key.indices[1:])
return key._rebuild(*indices, name=name, shape=shape,
alias=self.alias, halo=None)
elif key is not None and not isinstance(key, Iterable):
raise ValueError("`%s` must be either SubFunction "
"or iterable (e.g., list, np.ndarray)" % key)

name = '%s_%s' % (self.name, suffix)
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim)

if key is None:
# Fallback to default behaviour
dtype = dtype or self.dtype
else:
if key is not None:
key = np.array(key)

if (shape != key.shape[:2] and key.shape != (shape[1],)) and \
if (shape != key.shape and key.shape != (shape[1],)) and \
self._distributor.nprocs == 1:
raise ValueError("Incompatible shape for %s, `%s`; expected `%s`" %
(suffix, key.shape[:2], shape))
Expand All @@ -136,12 +139,8 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
else:
dtype = dtype or self.dtype

if key is not None and key.ndim > 2:
shape = (*shape, *key.shape[2:])
dimensions = (*dimensions, *mkdims("i", n=key.ndim-2))

sf = SubFunction(
name=name, parent=self, dtype=dtype, dimensions=dimensions,
name=name, dtype=dtype, dimensions=dimensions,
shape=shape, space_order=0, initializer=key, alias=self.alias,
distributor=self._distributor
)
Expand Down Expand Up @@ -657,20 +656,6 @@ def time_dim(self):
"""The time Dimension."""
return self._time_dim

@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
if not dimensions:
dimensions = (kwargs['grid'].time_dim,
Dimension(name='p_%s' % kwargs["name"]))

if args:
indices = args
else:
indices = dimensions

return dimensions, indices

@classmethod
def __shape_setup__(cls, **kwargs):
shape = kwargs.get('shape')
Expand All @@ -686,6 +671,18 @@ def __shape_setup__(cls, **kwargs):

return tuple(shape)

@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
if not dimensions:
dimensions = (kwargs['grid'].time_dim,
Dimension(name='p_%s' % kwargs["name"]),)

if args:
return tuple(dimensions), tuple(args)
else:
return dimensions, dimensions

@property
def nt(self):
return self.shape[self._time_position]
Expand Down Expand Up @@ -791,7 +788,7 @@ class SparseFunction(AbstractSparseFunction):

_sub_functions = ('coordinates',)

__rkwargs__ = AbstractSparseFunction.__rkwargs__ + ('coordinates_data',)
__rkwargs__ = AbstractSparseFunction.__rkwargs__ + ('coordinates',)

def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)
Expand Down Expand Up @@ -1014,8 +1011,8 @@ class PrecomputedSparseFunction(AbstractSparseFunction):
_sub_functions = ('gridpoints', 'coordinates', 'interpolation_coeffs')

__rkwargs__ = (AbstractSparseFunction.__rkwargs__ +
('r', 'gridpoints_data', 'coordinates_data',
'interpolation_coeffs_data'))
('r', 'gridpoints', 'coordinates',
'interpolation_coeffs'))

def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)
Expand All @@ -1032,13 +1029,14 @@ def __init_finalize__(self, *args, **kwargs):
if r <= 0:
raise ValueError('`r` must be > 0')
# Make sure radius matches the coefficients size
nr = interpolation_coeffs.shape[-1]
if nr // 2 != r:
if nr == r:
r = r // 2
else:
raise ValueError("Interpolation coefficients shape %d do "
"not match specified radius %d" % (r, nr))
if interpolation_coeffs is not None:
nr = interpolation_coeffs.shape[-1]
if nr // 2 != r:
if nr == r:
r = r // 2
else:
raise ValueError("Interpolation coefficients shape %d do "
"not match specified radius %d" % (r, nr))
self._radius = r

if coordinates is not None and gridpoints is not None:
Expand Down Expand Up @@ -1179,6 +1177,15 @@ class PrecomputedSparseTimeFunction(AbstractSparseTimeFunction,
PrecomputedSparseFunction.__rkwargs__))


# *** MatrixSparse*Function API
# This is mostly legacy stuff which often escapes the devito's modus operandi

class DynamicSubFunction(SubFunction):

def _arg_defaults(self, **kwargs):
return {}


class MatrixSparseTimeFunction(AbstractSparseTimeFunction):
"""
A specialised type of SparseTimeFunction where the interpolation is externally
Expand Down Expand Up @@ -1378,7 +1385,7 @@ def __init_finalize__(self, *args, **kwargs):
else:
nnz_size = 1

self._mrow = SubFunction(
self._mrow = DynamicSubFunction(
name='mrow_%s' % self.name,
dtype=np.int32,
dimensions=(self.nnzdim,),
Expand All @@ -1387,7 +1394,7 @@ def __init_finalize__(self, *args, **kwargs):
parent=self,
allocator=self._allocator,
)
self._mcol = SubFunction(
self._mcol = DynamicSubFunction(
name='mcol_%s' % self.name,
dtype=np.int32,
dimensions=(self.nnzdim,),
Expand All @@ -1396,7 +1403,7 @@ def __init_finalize__(self, *args, **kwargs):
parent=self,
allocator=self._allocator,
)
self._mval = SubFunction(
self._mval = DynamicSubFunction(
name='mval_%s' % self.name,
dtype=self.dtype,
dimensions=(self.nnzdim,),
Expand All @@ -1413,8 +1420,8 @@ def __init_finalize__(self, *args, **kwargs):
self.par_dim_to_nnz_dim = DynamicDimension('par_dim_to_nnz_%s' % self.name)

# This map acts as an indirect sort of the sources according to their
# position along the parallelisation Dimension
self._par_dim_to_nnz_map = SubFunction(
# position along the parallelisation dimension
self._par_dim_to_nnz_map = DynamicSubFunction(
name='par_dim_to_nnz_map_%s' % self.name,
dtype=np.int32,
dimensions=(self.par_dim_to_nnz_dim,),
Expand All @@ -1423,7 +1430,7 @@ def __init_finalize__(self, *args, **kwargs):
space_order=0,
parent=self,
)
self._par_dim_to_nnz_m = SubFunction(
self._par_dim_to_nnz_m = DynamicSubFunction(
name='par_dim_to_nnz_m_%s' % self.name,
dtype=np.int32,
dimensions=(self._par_dim,),
Expand All @@ -1432,7 +1439,7 @@ def __init_finalize__(self, *args, **kwargs):
space_order=0,
parent=self,
)
self._par_dim_to_nnz_M = SubFunction(
self._par_dim_to_nnz_M = DynamicSubFunction(
name='par_dim_to_nnz_M_%s' % self.name,
dtype=np.int32,
dimensions=(self._par_dim,),
Expand Down Expand Up @@ -1671,23 +1678,6 @@ def inject(self, field, expr, u_t=None, p_t=None):

return out

@classmethod
def __indices_setup__(cls, *args, **kwargs):
"""
Return the default Dimension indices for a given data shape.
"""
dimensions = kwargs.get('dimensions')
if dimensions is None:
dimensions = (kwargs['grid'].time_dim, Dimension(
name='p_%s' % kwargs["name"]))

if args:
indices = args
else:
indices = dimensions

return dimensions, indices

@classmethod
def __shape_setup__(cls, **kwargs):
# This happens before __init__, so we have to get 'npoint'
Expand Down
22 changes: 9 additions & 13 deletions examples/seismic/inversion/inversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,15 @@ def compute_residual(res, dobs, dsyn):
"""
Computes the data residual dsyn - dobs into residual
"""
if res.grid.distributor.is_parallel:
# If we run with MPI, we have to compute the residual via an operator
# First make sure we can take the difference and that receivers are at the
# same position
assert np.allclose(dobs.coordinates.data[:], dsyn.coordinates.data)
assert np.allclose(res.coordinates.data[:], dsyn.coordinates.data)
# Create a difference operator
diff_eq = Eq(res, dsyn.subs({dsyn.dimensions[-1]: res.dimensions[-1]}) -
dobs.subs({dobs.dimensions[-1]: res.dimensions[-1]}))
Operator(diff_eq)()
else:
# A simple data difference is enough in serial
res.data[:] = dsyn.data[:] - dobs.data[:]
# If we run with MPI, we have to compute the residual via an operator
# First make sure we can take the difference and that receivers are at the
# same position
assert np.allclose(dobs.coordinates.data[:], dsyn.coordinates.data)
assert np.allclose(res.coordinates.data[:], dsyn.coordinates.data)
# Create a difference operator
diff_eq = Eq(res, dsyn.subs({dsyn.dimensions[-1]: res.dimensions[-1]}) -
dobs.subs({dobs.dimensions[-1]: res.dimensions[-1]}))
Operator(diff_eq)()

return res

Expand Down
Loading

0 comments on commit 0d4c099

Please sign in to comment.