Skip to content

Commit

Permalink
api: fix derivative kwargs at call
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Oct 15, 2024
1 parent 4f6cc56 commit 2a85ad5
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 26 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/examples-mpi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ jobs:
- name: Checkout devito
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.11

- name: Setup MPI
uses: mpi4py/setup-mpi@v1
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:

- name: Install dependencies
run: |
pip install -e .[tests,extras]
pip install -e .
- name: Tests in examples
run: |
Expand Down
59 changes: 34 additions & 25 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,35 +219,44 @@ def _process_weights(cls, **kwargs):
return as_tuple(weights)

def __call__(self, x0=None, fd_order=None, side=None, method=None, weights=None):
side = side or self._side
method = method or self._method
weights = weights if weights is not None else self._weights

x0 = self._process_x0(self.dims, x0=x0)
_x0 = frozendict({**self.x0, **x0})

_fd_order = dict(self.fd_order.getters)
try:
_fd_order.update(fd_order or {})
except TypeError:
assert self.ndims == 1
_fd_order.update({self.dims[0]: fd_order or self.fd_order[0]})
except AttributeError:
raise TypeError("fd_order incompatible with dimensions")
rkw = {}
if side is not None:
rkw['side'] = side
if method is not None:
rkw['method'] = method
if weights is not None:
rkw['weights'] = weights

if x0 is not None:
x0 = self._process_x0(self.dims, x0=x0)
rkw['x0'] = frozendict({**self.x0, **x0})

if fd_order is not None:
try:
_fd_order = dict(fd_order)
except TypeError:
assert self.ndims == 1
_fd_order = {self.dims[0]: fd_order}
except AttributeError:
raise TypeError("fd_order incompatible with dimensions")

if isinstance(self.expr, Derivative):
# In case this was called on a perfect cross-derivative `u.dxdy`
# we need to propagate the call to the nested derivative
x0s = self._filter_dims(self.expr._filter_dims(_x0), neg=True)
expr = self.expr(x0=x0s, fd_order=self.expr._filter_dims(_fd_order),
side=side, method=method)
else:
expr = self.expr

_fd_order = self._filter_dims(_fd_order, as_tuple=True)

return self._rebuild(fd_order=_fd_order, x0=_x0, side=side, method=method,
weights=weights, expr=expr)
rkwe = dict(rkw)
if 'x0' in rkwe:
rkwe['x0'] = self._filter_dims(self.expr._filter_dims(rkw['x0']),
neg=True)
if fd_order is not None:
fdo = self.expr._filter_dims(_fd_order)
if fdo:
rkwe['fd_order'] = fdo
rkw['expr'] = self.expr(**rkwe)

if fd_order is not None:
rkw['fd_order'] = self._filter_dims(_fd_order, as_tuple=True)

return self._rebuild(**rkw)

def _rebuild(self, *args, **kwargs):
kwargs['preprocessed'] = True
Expand Down
2 changes: 2 additions & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ dependencies:
- pip>=21.1.2
- pip:
- -r requirements.txt
- -r requirements-optional.txt
- -r requirements-testing.txt
18 changes: 18 additions & 0 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,24 @@ def test_cross_newnest(self):

assert f.dxdy == f.dx.dy

def test_nested_call(self):
grid = Grid((11, 11))
x, y = grid.dimensions
f = Function(name="f", grid=grid, space_order=8)

deriv = Derivative(f, x, y, deriv_order=(0, 0), fd_order=(2, 2),
x0={x: x-x.spacing/2, y: y+y.spacing/2}).dy(x0=y-y.spacing/2)

derivc = Derivative(f.dy(x0=y-y.spacing/2), x, y, deriv_order=(0, 0),
fd_order=(2, 2), x0={x: x-x.spacing/2, y: y+y.spacing/2})

assert deriv.expr.fd_order == (2, 2)
assert deriv.expr.deriv_order == (0, 0)
assert deriv.expr.x0 == {x: x-x.spacing/2, y: y+y.spacing/2}

# Should be commutative
assert simplify(deriv.evaluate - derivc.evaluate) == 0


class TestTwoStageEvaluation:

Expand Down

0 comments on commit 2a85ad5

Please sign in to comment.