From 914de040e73a43547eb0c82b485d15eaa0dd8360 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 15 Oct 2024 11:22:09 -0400 Subject: [PATCH] api: fix derivative kwargs at call --- .github/workflows/examples.yml | 4 -- devito/finite_differences/derivative.py | 59 ++++++++++++++----------- environment-dev.yml | 2 + tests/test_derivatives.py | 18 ++++++++ 4 files changed, 54 insertions(+), 29 deletions(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 4df8a929b6..b04155b586 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -44,10 +44,6 @@ jobs: auto-activate-base: false python-version: 3.11 - - name: Install dependencies - run: | - pip install -e .[tests,extras] - - name: Tests in examples run: | py.test --cov --cov-config=.coveragerc --cov-report=xml examples/ diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index 8c4a9a813e..1cf07642ed 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -219,35 +219,44 @@ def _process_weights(cls, **kwargs): return as_tuple(weights) def __call__(self, x0=None, fd_order=None, side=None, method=None, weights=None): - side = side or self._side - method = method or self._method - weights = weights if weights is not None else self._weights - - x0 = self._process_x0(self.dims, x0=x0) - _x0 = frozendict({**self.x0, **x0}) - - _fd_order = dict(self.fd_order.getters) - try: - _fd_order.update(fd_order or {}) - except TypeError: - assert self.ndims == 1 - _fd_order.update({self.dims[0]: fd_order or self.fd_order[0]}) - except AttributeError: - raise TypeError("fd_order incompatible with dimensions") + rkw = {} + if side is not None: + rkw['side'] = side + if method is not None: + rkw['method'] = method + if weights is not None: + rkw['weights'] = weights + + if x0 is not None: + x0 = self._process_x0(self.dims, x0=x0) + rkw['x0'] = frozendict({**self.x0, **x0}) + + if fd_order is not None: + try: + _fd_order = dict(fd_order) + except TypeError: + assert self.ndims == 1 + _fd_order = {self.dims[0]: fd_order} + except AttributeError: + raise TypeError("fd_order incompatible with dimensions") if isinstance(self.expr, Derivative): # In case this was called on a perfect cross-derivative `u.dxdy` # we need to propagate the call to the nested derivative - x0s = self._filter_dims(self.expr._filter_dims(_x0), neg=True) - expr = self.expr(x0=x0s, fd_order=self.expr._filter_dims(_fd_order), - side=side, method=method) - else: - expr = self.expr - - _fd_order = self._filter_dims(_fd_order, as_tuple=True) - - return self._rebuild(fd_order=_fd_order, x0=_x0, side=side, method=method, - weights=weights, expr=expr) + rkwe = dict(rkw) + if 'x0' in rkwe: + rkwe['x0'] = self._filter_dims(self.expr._filter_dims(rkw['x0']), + neg=True) + if fd_order is not None: + fdo = self.expr._filter_dims(_fd_order) + if fdo: + rkwe['fd_order'] = fdo + rkw['expr'] = self.expr(**rkwe) + + if fd_order is not None: + rkw['fd_order'] = self._filter_dims(_fd_order, as_tuple=True) + + return self._rebuild(**rkw) def _rebuild(self, *args, **kwargs): kwargs['preprocessed'] = True diff --git a/environment-dev.yml b/environment-dev.yml index 1327c1a2d5..4f32907712 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -6,3 +6,5 @@ dependencies: - pip>=21.1.2 - pip: - -r requirements.txt + - -r requirements-optional.txt + - -r requirements-testing.txt \ No newline at end of file diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 30893b2afb..d78a3a95fb 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -743,6 +743,24 @@ def test_cross_newnest(self): assert f.dxdy == f.dx.dy + def test_nested_call(self): + grid = Grid((11, 11)) + x, y = grid.dimensions + f = Function(name="f", grid=grid, space_order=8) + + deriv = Derivative(f, x, y, deriv_order=(0, 0), fd_order=(2, 2), + x0={x: x-x.spacing/2, y: y+y.spacing/2}).dy(x0=y-y.spacing/2) + + derivc = Derivative(f.dy(x0=y-y.spacing/2), x, y, deriv_order=(0, 0), + fd_order=(2, 2), x0={x: x-x.spacing/2, y: y+y.spacing/2}) + + assert deriv.expr.fd_order == (2, 2) + assert deriv.expr.deriv_order == (0, 0) + assert deriv.expr.x0 == {x: x-x.spacing/2, y: y+y.spacing/2} + + # Should be commutative + assert simplify(deriv.evaluate - derivc.evaluate) == 0 + class TestTwoStageEvaluation: