Skip to content

Commit

Permalink
api: fix subfunction handling (subs/rebuild/...)
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 7, 2023
1 parent 3209435 commit 74a92ab
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 70 deletions.
7 changes: 5 additions & 2 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from devito.finite_differences import Differentiable, generate_fd_shortcuts
from devito.tools import (ReducerMap, as_tuple, c_restrict_void_p, flatten, is_integer,
memoized_meth, dtype_to_ctype, humanbytes)
from devito.types.dimension import Dimension, DynamicDimension
from devito.types.dimension import Dimension
from devito.types.args import ArgProvider
from devito.types.caching import CacheManager
from devito.types.basic import AbstractFunction, Size
Expand Down Expand Up @@ -1040,7 +1040,10 @@ def __indices_setup__(cls, *args, **kwargs):
dimensions = grid.dimensions

if args:
return tuple(dimensions), tuple(args)
assert len(args) == len(dimensions)
dims = tuple(a if isinstance(a, Dimension) else d
for (a, d) in zip(args, dimensions))
return tuple(dims), tuple(args)

# Staggered indices
staggered = kwargs.get("staggered", None)
Expand Down
112 changes: 46 additions & 66 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@ def __indices_setup__(cls, *args, **kwargs):
dimensions = (Dimension(name='p_%s' % kwargs["name"]),)

if args:
indices = args
assert len(args) == len(dimensions)
dims = tuple(a if isinstance(a, Dimension) else d
for (a, d) in zip(args, dimensions))
return tuple(dims), tuple(args)
else:
indices = dimensions

return dimensions, indices
return dimensions, dimensions

@classmethod
def __shape_setup__(cls, **kwargs):
Expand All @@ -80,16 +81,6 @@ def __shape_setup__(cls, **kwargs):
shape = (glb_npoint[grid.distributor.myrank],)
return shape

def func(self, *args, **kwargs):
# Rebuild subfunctions first to avoid new data creation as we have to use `_data`
# as a reconstruction kwargs to avoid the circular dependency
# with the parent in SubFunction
# This is also necessary to avoid shape issue in the SubFunction with mpi
for s in self._sub_functions:
if getattr(self, s) is not None:
kwargs.update({s: getattr(self, s).func(*args, **kwargs)})
return super().func(*args, **kwargs)

def __fd_setup__(self):
"""
Dynamically add derivative short-cuts.
Expand All @@ -108,24 +99,32 @@ def __distributor_setup__(self, **kwargs):
)

def __subfunc_setup__(self, key, suffix, dtype=None):
# Shape and dimensions from args
name = '%s_%s' % (self.name, suffix)
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim)

if key is not None and not isinstance(key, SubFunction):
key = np.array(key)

if key is not None and key.ndim > 2:
shape = (*shape, *key.shape[2:])
dimensions = (*dimensions, *mkdims("i", n=key.ndim-2))

# Check if already a SubFunction
if isinstance(key, SubFunction):
return key
# Need to rebuild so the dimensions match the parent SparseFunction
return key._rebuild(name=name, dimensions=dimensions, shape=shape,
alias=self.alias, halo=None)
elif key is not None and not isinstance(key, Iterable):
raise ValueError("`%s` must be either SubFunction "
"or iterable (e.g., list, np.ndarray)" % key)

name = '%s_%s' % (self.name, suffix)
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim)

if key is None:
# Fallback to default behaviour
dtype = dtype or self.dtype
else:
if key is not None:
key = np.array(key)

if (shape != key.shape[:2] and key.shape != (shape[1],)) and \
if (shape != key.shape and key.shape != (shape[1],)) and \
self._distributor.nprocs == 1:
raise ValueError("Incompatible shape for %s, `%s`; expected `%s`" %
(suffix, key.shape[:2], shape))
Expand All @@ -136,12 +135,8 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
else:
dtype = dtype or self.dtype

if key is not None and key.ndim > 2:
shape = (*shape, *key.shape[2:])
dimensions = (*dimensions, *mkdims("i", n=key.ndim-2))

sf = SubFunction(
name=name, parent=self, dtype=dtype, dimensions=dimensions,
name=name, dtype=dtype, dimensions=dimensions,
shape=shape, space_order=0, initializer=key, alias=self.alias,
distributor=self._distributor
)
Expand Down Expand Up @@ -657,20 +652,6 @@ def time_dim(self):
"""The time Dimension."""
return self._time_dim

@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
if not dimensions:
dimensions = (kwargs['grid'].time_dim,
Dimension(name='p_%s' % kwargs["name"]))

if args:
indices = args
else:
indices = dimensions

return dimensions, indices

@classmethod
def __shape_setup__(cls, **kwargs):
shape = kwargs.get('shape')
Expand All @@ -686,6 +667,21 @@ def __shape_setup__(cls, **kwargs):

return tuple(shape)

@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
if not dimensions:
dimensions = (kwargs['grid'].time_dim,
Dimension(name='p_%s' % kwargs["name"]))

if args:
assert len(args) == len(dimensions)
dims = tuple(a if isinstance(a, Dimension) else d
for (a, d) in zip(args, dimensions))
return tuple(dims), tuple(args)
else:
return dimensions, dimensions

@property
def nt(self):
return self.shape[self._time_position]
Expand Down Expand Up @@ -1032,13 +1028,14 @@ def __init_finalize__(self, *args, **kwargs):
if r <= 0:
raise ValueError('`r` must be > 0')
# Make sure radius matches the coefficients size
nr = interpolation_coeffs.shape[-1]
if nr // 2 != r:
if nr == r:
r = r // 2
else:
raise ValueError("Interpolation coefficients shape %d do "
"not match specified radius %d" % (r, nr))
if interpolation_coeffs is not None:
nr = interpolation_coeffs.shape[-1]
if nr // 2 != r:
if nr == r:
r = r // 2
else:
raise ValueError("Interpolation coefficients shape %d do "
"not match specified radius %d" % (r, nr))
self._radius = r

if coordinates is not None and gridpoints is not None:
Expand Down Expand Up @@ -1680,23 +1677,6 @@ def inject(self, field, expr, u_t=None, p_t=None):

return out

@classmethod
def __indices_setup__(cls, *args, **kwargs):
"""
Return the default Dimension indices for a given data shape.
"""
dimensions = kwargs.get('dimensions')
if dimensions is None:
dimensions = (kwargs['grid'].time_dim, Dimension(
name='p_%s' % kwargs["name"]))

if args:
indices = args
else:
indices = dimensions

return dimensions, indices

@classmethod
def __shape_setup__(cls, **kwargs):
# This happens before __init__, so we have to get 'npoint'
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_precomputed_sparse_function(self, mode, pickle):

sf = PrecomputedSparseTimeFunction(
name='sf', grid=grid, r=2, npoint=3, nt=5,
interpolation_coeffs=np.ndarray(shape=(3, 2, 2)), **kw
interpolation_coeffs=np.random.randn(3, 2, 2), **kw
)
sf.data[2, 1] = 5.

Expand Down
64 changes: 63 additions & 1 deletion tests/test_msparse.py → tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import numpy as np
import scipy.sparse

from devito import Grid, TimeFunction, Eq, Operator, MatrixSparseTimeFunction
from devito import Grid, TimeFunction, Eq, Operator, Dimension
from devito import (SparseFunction, SparseTimeFunction, PrecomputedSparseFunction,
PrecomputedSparseTimeFunction, MatrixSparseTimeFunction)


_sptypes = [SparseFunction, SparseTimeFunction,
PrecomputedSparseFunction, PrecomputedSparseTimeFunction]


class TestMatrixSparseTimeFunction(object):
Expand Down Expand Up @@ -394,5 +400,61 @@ def test_mpi(self):
assert sf.data[0, 0] == -3.0 # 1 * (1 * 1) * 1 + (-1) * (2 * 2) * 1


class TestSparseFunction(object):

@pytest.mark.parametrize('sptype', _sptypes)
def test_rebuild(self, sptype):
grid = Grid((3, 3, 3))
# Base object
sp = sptype(name="s", grid=grid, npoint=1, nt=11, r=2,
interpolation_coeffs=np.random.randn(1, 3, 2),
coordinates=np.random.randn(1, 3))

# Check subfunction setup
for subf in sp._sub_functions:
if getattr(sp, subf) is not None:
assert getattr(sp, subf).name.startswith("s_")

# Rebuild with different name, this should drop the function
# and create new data
sp2 = sp._rebuild(name="sr")

# Check new subfunction
for subf in sp2._sub_functions:
if getattr(sp2, subf) is not None:
assert getattr(sp2, subf).name.startswith("sr_")
assert np.all(getattr(sp2, subf).data == 0)

# Rebuild with different name as an alias
sp2 = sp._rebuild(name="sr2", alias=True)
for subf in sp2._sub_functions:
if getattr(sp2, subf) is not None:
assert getattr(sp2, subf).name.startswith("sr2_")
assert getattr(sp2, subf).data is None

@pytest.mark.parametrize('sptype', _sptypes)
def test_subs(self, sptype):
grid = Grid((3, 3, 3))
# Base object
sp = sptype(name="s", grid=grid, npoint=1, nt=11, r=2,
interpolation_coeffs=np.random.randn(1, 3, 2),
coordinates=np.random.randn(1, 3))

# Check subfunction setup
for subf in sp._sub_functions:
if getattr(sp, subf) is not None:
assert getattr(sp, subf).dimensions[0] == sp._sparse_dim

# Do substitution on sparse dimension
new_spdim = Dimension(name="newsp")

sps = sp._subs(sp._sparse_dim, new_spdim)
assert sps._sparse_dim == new_spdim
for subf in sps._sub_functions:
if getattr(sps, subf) is not None:
assert getattr(sps, subf).dimensions[0] == new_spdim
assert np.all(getattr(sps, subf).data == getattr(sp, subf).data)


if __name__ == "__main__":
TestMatrixSparseTimeFunction().test_mpi()

0 comments on commit 74a92ab

Please sign in to comment.