From 74a92abba22bebb2344f037cc8d2c50d60418e6d Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 7 Sep 2023 09:53:00 -0400 Subject: [PATCH] api: fix subfunction handling (subs/rebuild/...) --- devito/types/dense.py | 7 +- devito/types/sparse.py | 112 +++++++++------------- tests/test_pickle.py | 2 +- tests/{test_msparse.py => test_sparse.py} | 64 ++++++++++++- 4 files changed, 115 insertions(+), 70 deletions(-) rename tests/{test_msparse.py => test_sparse.py} (84%) diff --git a/devito/types/dense.py b/devito/types/dense.py index 4c912d2704a..b454075edc3 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -20,7 +20,7 @@ from devito.finite_differences import Differentiable, generate_fd_shortcuts from devito.tools import (ReducerMap, as_tuple, c_restrict_void_p, flatten, is_integer, memoized_meth, dtype_to_ctype, humanbytes) -from devito.types.dimension import Dimension, DynamicDimension +from devito.types.dimension import Dimension from devito.types.args import ArgProvider from devito.types.caching import CacheManager from devito.types.basic import AbstractFunction, Size @@ -1040,7 +1040,10 @@ def __indices_setup__(cls, *args, **kwargs): dimensions = grid.dimensions if args: - return tuple(dimensions), tuple(args) + assert len(args) == len(dimensions) + dims = tuple(a if isinstance(a, Dimension) else d + for (a, d) in zip(args, dimensions)) + return tuple(dims), tuple(args) # Staggered indices staggered = kwargs.get("staggered", None) diff --git a/devito/types/sparse.py b/devito/types/sparse.py index f036a68c9ca..a88d8d1e299 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -61,11 +61,12 @@ def __indices_setup__(cls, *args, **kwargs): dimensions = (Dimension(name='p_%s' % kwargs["name"]),) if args: - indices = args + assert len(args) == len(dimensions) + dims = tuple(a if isinstance(a, Dimension) else d + for (a, d) in zip(args, dimensions)) + return tuple(dims), tuple(args) else: - indices = dimensions - - return dimensions, indices + return dimensions, dimensions @classmethod def __shape_setup__(cls, **kwargs): @@ -80,16 +81,6 @@ def __shape_setup__(cls, **kwargs): shape = (glb_npoint[grid.distributor.myrank],) return shape - def func(self, *args, **kwargs): - # Rebuild subfunctions first to avoid new data creation as we have to use `_data` - # as a reconstruction kwargs to avoid the circular dependency - # with the parent in SubFunction - # This is also necessary to avoid shape issue in the SubFunction with mpi - for s in self._sub_functions: - if getattr(self, s) is not None: - kwargs.update({s: getattr(self, s).func(*args, **kwargs)}) - return super().func(*args, **kwargs) - def __fd_setup__(self): """ Dynamically add derivative short-cuts. @@ -108,24 +99,32 @@ def __distributor_setup__(self, **kwargs): ) def __subfunc_setup__(self, key, suffix, dtype=None): + # Shape and dimensions from args + name = '%s_%s' % (self.name, suffix) + dimensions = (self._sparse_dim, Dimension(name='d')) + shape = (self.npoint, self.grid.dim) + + if key is not None and not isinstance(key, SubFunction): + key = np.array(key) + + if key is not None and key.ndim > 2: + shape = (*shape, *key.shape[2:]) + dimensions = (*dimensions, *mkdims("i", n=key.ndim-2)) + + # Check if already a SubFunction if isinstance(key, SubFunction): - return key + # Need to rebuild so the dimensions match the parent SparseFunction + return key._rebuild(name=name, dimensions=dimensions, shape=shape, + alias=self.alias, halo=None) elif key is not None and not isinstance(key, Iterable): raise ValueError("`%s` must be either SubFunction " "or iterable (e.g., list, np.ndarray)" % key) - name = '%s_%s' % (self.name, suffix) - dimensions = (self._sparse_dim, Dimension(name='d')) - shape = (self.npoint, self.grid.dim) - if key is None: # Fallback to default behaviour dtype = dtype or self.dtype else: - if key is not None: - key = np.array(key) - - if (shape != key.shape[:2] and key.shape != (shape[1],)) and \ + if (shape != key.shape and key.shape != (shape[1],)) and \ self._distributor.nprocs == 1: raise ValueError("Incompatible shape for %s, `%s`; expected `%s`" % (suffix, key.shape[:2], shape)) @@ -136,12 +135,8 @@ def __subfunc_setup__(self, key, suffix, dtype=None): else: dtype = dtype or self.dtype - if key is not None and key.ndim > 2: - shape = (*shape, *key.shape[2:]) - dimensions = (*dimensions, *mkdims("i", n=key.ndim-2)) - sf = SubFunction( - name=name, parent=self, dtype=dtype, dimensions=dimensions, + name=name, dtype=dtype, dimensions=dimensions, shape=shape, space_order=0, initializer=key, alias=self.alias, distributor=self._distributor ) @@ -657,20 +652,6 @@ def time_dim(self): """The time Dimension.""" return self._time_dim - @classmethod - def __indices_setup__(cls, *args, **kwargs): - dimensions = as_tuple(kwargs.get('dimensions')) - if not dimensions: - dimensions = (kwargs['grid'].time_dim, - Dimension(name='p_%s' % kwargs["name"])) - - if args: - indices = args - else: - indices = dimensions - - return dimensions, indices - @classmethod def __shape_setup__(cls, **kwargs): shape = kwargs.get('shape') @@ -686,6 +667,21 @@ def __shape_setup__(cls, **kwargs): return tuple(shape) + @classmethod + def __indices_setup__(cls, *args, **kwargs): + dimensions = as_tuple(kwargs.get('dimensions')) + if not dimensions: + dimensions = (kwargs['grid'].time_dim, + Dimension(name='p_%s' % kwargs["name"])) + + if args: + assert len(args) == len(dimensions) + dims = tuple(a if isinstance(a, Dimension) else d + for (a, d) in zip(args, dimensions)) + return tuple(dims), tuple(args) + else: + return dimensions, dimensions + @property def nt(self): return self.shape[self._time_position] @@ -1032,13 +1028,14 @@ def __init_finalize__(self, *args, **kwargs): if r <= 0: raise ValueError('`r` must be > 0') # Make sure radius matches the coefficients size - nr = interpolation_coeffs.shape[-1] - if nr // 2 != r: - if nr == r: - r = r // 2 - else: - raise ValueError("Interpolation coefficients shape %d do " - "not match specified radius %d" % (r, nr)) + if interpolation_coeffs is not None: + nr = interpolation_coeffs.shape[-1] + if nr // 2 != r: + if nr == r: + r = r // 2 + else: + raise ValueError("Interpolation coefficients shape %d do " + "not match specified radius %d" % (r, nr)) self._radius = r if coordinates is not None and gridpoints is not None: @@ -1680,23 +1677,6 @@ def inject(self, field, expr, u_t=None, p_t=None): return out - @classmethod - def __indices_setup__(cls, *args, **kwargs): - """ - Return the default Dimension indices for a given data shape. - """ - dimensions = kwargs.get('dimensions') - if dimensions is None: - dimensions = (kwargs['grid'].time_dim, Dimension( - name='p_%s' % kwargs["name"])) - - if args: - indices = args - else: - indices = dimensions - - return dimensions, indices - @classmethod def __shape_setup__(cls, **kwargs): # This happens before __init__, so we have to get 'npoint' diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 62423b2c158..16f44bdaede 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -111,7 +111,7 @@ def test_precomputed_sparse_function(self, mode, pickle): sf = PrecomputedSparseTimeFunction( name='sf', grid=grid, r=2, npoint=3, nt=5, - interpolation_coeffs=np.ndarray(shape=(3, 2, 2)), **kw + interpolation_coeffs=np.random.randn(3, 2, 2), **kw ) sf.data[2, 1] = 5. diff --git a/tests/test_msparse.py b/tests/test_sparse.py similarity index 84% rename from tests/test_msparse.py rename to tests/test_sparse.py index 5cbfde848a0..51ef321c0bc 100644 --- a/tests/test_msparse.py +++ b/tests/test_sparse.py @@ -4,7 +4,13 @@ import numpy as np import scipy.sparse -from devito import Grid, TimeFunction, Eq, Operator, MatrixSparseTimeFunction +from devito import Grid, TimeFunction, Eq, Operator, Dimension +from devito import (SparseFunction, SparseTimeFunction, PrecomputedSparseFunction, + PrecomputedSparseTimeFunction, MatrixSparseTimeFunction) + + +_sptypes = [SparseFunction, SparseTimeFunction, + PrecomputedSparseFunction, PrecomputedSparseTimeFunction] class TestMatrixSparseTimeFunction(object): @@ -394,5 +400,61 @@ def test_mpi(self): assert sf.data[0, 0] == -3.0 # 1 * (1 * 1) * 1 + (-1) * (2 * 2) * 1 +class TestSparseFunction(object): + + @pytest.mark.parametrize('sptype', _sptypes) + def test_rebuild(self, sptype): + grid = Grid((3, 3, 3)) + # Base object + sp = sptype(name="s", grid=grid, npoint=1, nt=11, r=2, + interpolation_coeffs=np.random.randn(1, 3, 2), + coordinates=np.random.randn(1, 3)) + + # Check subfunction setup + for subf in sp._sub_functions: + if getattr(sp, subf) is not None: + assert getattr(sp, subf).name.startswith("s_") + + # Rebuild with different name, this should drop the function + # and create new data + sp2 = sp._rebuild(name="sr") + + # Check new subfunction + for subf in sp2._sub_functions: + if getattr(sp2, subf) is not None: + assert getattr(sp2, subf).name.startswith("sr_") + assert np.all(getattr(sp2, subf).data == 0) + + # Rebuild with different name as an alias + sp2 = sp._rebuild(name="sr2", alias=True) + for subf in sp2._sub_functions: + if getattr(sp2, subf) is not None: + assert getattr(sp2, subf).name.startswith("sr2_") + assert getattr(sp2, subf).data is None + + @pytest.mark.parametrize('sptype', _sptypes) + def test_subs(self, sptype): + grid = Grid((3, 3, 3)) + # Base object + sp = sptype(name="s", grid=grid, npoint=1, nt=11, r=2, + interpolation_coeffs=np.random.randn(1, 3, 2), + coordinates=np.random.randn(1, 3)) + + # Check subfunction setup + for subf in sp._sub_functions: + if getattr(sp, subf) is not None: + assert getattr(sp, subf).dimensions[0] == sp._sparse_dim + + # Do substitution on sparse dimension + new_spdim = Dimension(name="newsp") + + sps = sp._subs(sp._sparse_dim, new_spdim) + assert sps._sparse_dim == new_spdim + for subf in sps._sub_functions: + if getattr(sps, subf) is not None: + assert getattr(sps, subf).dimensions[0] == new_spdim + assert np.all(getattr(sps, subf).data == getattr(sp, subf).data) + + if __name__ == "__main__": TestMatrixSparseTimeFunction().test_mpi()