Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Fix Sparse rebuild with explicit empty subfunctions #2473

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 61 additions & 41 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,26 @@ def __distributor_setup__(self, **kwargs):

return distributor

def __subfunc_setup__(self, key, suffix, dtype=None):
def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
key = None
for k in keys:
if k not in kwargs:
continue
elif kwargs[k] is None:
# In cases such as rebuild,
# the subfunction may be passed explicitly as None
return None
else:
key = kwargs[k]
break
else:
if inkwargs:
# Only create the subfunction if provided. This is useful
# with PrecomputedSparseFunctions that can have different subfunctions
# to skip creating extra if another one has already
# been provided
return None

# Shape and dimensions from args
name = '%s_%s' % (self.name, suffix)

Expand Down Expand Up @@ -603,11 +622,15 @@ def _dist_subfunc_gather(self, sfuncd, subfunc):
# in `_dist_scatter` is here received; a sparse point that is received in
# `_dist_scatter` is here sent.

def _dist_scatter(self, data=None):
def _dist_scatter(self, alias=None, data=None):
key = alias or self
mapper = {self: self._dist_data_scatter(data=data)}
for i in self._sub_functions:
if getattr(self, i) is not None:
mapper.update(self._dist_subfunc_scatter(getattr(self, i)))
if getattr(key, i) is not None:
# Pick up alias' in case runtime SparseFunctions is missing
# a subfunction
sf = getattr(self, i) or getattr(key, i)
mapper.update(self._dist_subfunc_scatter(sf))
return mapper

def _eval_at(self, func):
Expand All @@ -629,7 +652,7 @@ def _arg_defaults(self, alias=None):

# Add in the sparse data (as well as any SubFunction data) belonging to
# self's local domain only
for k, v in self._dist_scatter().items():
for k, v in self._dist_scatter(alias=alias).items():
args[mapper[k].name] = v
for i, s in zip(mapper[k].indices, v.shape):
args.update(i._arg_defaults(_min=0, size=s))
Expand All @@ -647,7 +670,7 @@ def _arg_values(self, **kwargs):
else:
# We've been provided a pure-data replacement (array)
values = {}
for k, v in self._dist_scatter(new).items():
for k, v in self._dist_scatter(data=new).items():
values[k.name] = v
for i, s in zip(k.indices, v.shape):
size = s - sum(k._size_nodomain[i])
Expand Down Expand Up @@ -844,8 +867,8 @@ def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)

# Set up sparse point coordinates
coordinates = kwargs.get('coordinates', kwargs.get('coordinates_data'))
self._coordinates = self.__subfunc_setup__(coordinates, 'coords')
keys = ('coordinates', 'coordinates_data')
self._coordinates = self.__subfunc_setup__('coords', keys, **kwargs)
self._dist_origin = {self._coordinates: self.grid.origin_offset}

def __interp_setup__(self, interpolation='linear', r=None, **kwargs):
Expand Down Expand Up @@ -1096,52 +1119,49 @@ class PrecomputedSparseFunction(AbstractSparseFunction):
def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)

# Process kwargs
coordinates = kwargs.get('coordinates', kwargs.get('coordinates_data'))
gridpoints = kwargs.get('gridpoints', kwargs.get('gridpoints_data'))
interpolation_coeffs = kwargs.get('interpolation_coeffs',
kwargs.get('interpolation_coeffs_data'))
if not any(k in kwargs for k in ('coordinates', 'gridpoints',
'coordinates_data', 'gridpoints_data')):
raise ValueError("PrecomputedSparseFunction requires `coordinates`"
"or `gridpoints` arguments")

# Subfunctions setup
self._dist_origin = {}
dtype = kwargs.pop('dtype', self.grid.dtype)
self._gridpoints = self.__subfunc_setup__('gridpoints',
('gridpoints', 'gridpoints_data'),
inkwargs=True,
dtype=np.int32, **kwargs)
self._coordinates = self.__subfunc_setup__('coords',
('coordinates', 'coordinates_data'),
inkwargs=self._gridpoints is not None,
dtype=dtype, **kwargs)

if self._coordinates is not None:
self._dist_origin.update({self._coordinates: self.grid.origin_offset})
if self._gridpoints is not None:
self._dist_origin.update({self._gridpoints: self.grid.origin_ioffset})

# Setup the interpolation coefficients. These are compulsory
ckeys = ('interpolation_coeffs', 'interpolation_coeffs_data')
self._interpolation_coeffs = \
self.__subfunc_setup__('interp_coeffs', ckeys, dtype=dtype, **kwargs)

# Grid points per sparse point (2 in the case of bilinear and trilinear)
r = kwargs.get('r')
if not is_integer(r):
raise TypeError('Need `r` int argument')
if r <= 0:
raise ValueError('`r` must be > 0')
# Make sure radius matches the coefficients size
if interpolation_coeffs is not None:
nr = interpolation_coeffs.shape[-1]
if any(c in kwargs for c in ckeys):
nr = self._interpolation_coeffs.shape[-1]
if nr // 2 != r:
if nr == r:
r = r // 2
else:
raise ValueError("Interpolation coefficients shape %d do "
"not match specified radius %d" % (r, nr))
self._radius = r

if coordinates is not None and gridpoints is not None:
raise ValueError("Either `coordinates` or `gridpoints` must be "
"provided, but not both")

# Specifying only `npoints` is acceptable; this will require the user
# to setup the coordinates data later on
npoint = kwargs.get('npoint', None)
if self.npoint and coordinates is None and gridpoints is None:
coordinates = np.zeros((npoint, self.grid.dim))

if coordinates is not None:
self._coordinates = self.__subfunc_setup__(coordinates, 'coords')
self._gridpoints = None
self._dist_origin = {self._coordinates: self.grid.origin_offset}
else:
assert gridpoints is not None
self._coordinates = None
self._gridpoints = self.__subfunc_setup__(gridpoints, 'gridpoints',
dtype=np.int32)
self._dist_origin = {self._gridpoints: self.grid.origin_ioffset}

# Setup the interpolation coefficients. These are compulsory
self._interpolation_coeffs = \
self.__subfunc_setup__(interpolation_coeffs, 'interp_coeffs')
self._dist_origin.update({self._interpolation_coeffs: None})

self.interpolator = PrecomputedInterpolator(self)
Expand Down Expand Up @@ -2135,7 +2155,7 @@ def manual_scatter(self, *, data_all_zero=False):
**self._build_par_dim_to_nnz(scattered_gp, active_mrow),
}

def _dist_scatter(self, data=None):
def _dist_scatter(self, alias=None, data=None):
assert data is None
if self.scatter_result is None:
raise Exception("_dist_scatter called before manual_scatter called")
Expand Down
20 changes: 18 additions & 2 deletions tests/test_rebuild.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
import pytest

from devito import Dimension, Function
from devito.types import StencilDimension
from devito import Dimension, Function, Grid
from devito.types import StencilDimension, SparseFunction, PrecomputedSparseFunction
from devito.data.allocators import DataReference


Expand Down Expand Up @@ -65,3 +65,19 @@ def test_stencil_dimension_borked(self):

# TODO: Look into Symbol._cache_key and the way the key is generated
assert sd0 is sd1


class TestSparseFunction:

@pytest.mark.parametrize('sfunc', [SparseFunction, PrecomputedSparseFunction])
def test_none_subfunc(self, sfunc):
grid = Grid((4, 4))
coords = np.zeros((5, 2))

s = sfunc(name='s', grid=grid, npoint=5, coordinates=coords, r=1)

assert s.coordinates is not None

# Explicity set coordinates to None
sr = s._rebuild(function=None, initializer=None, coordinates=None)
assert sr.coordinates is None
Loading