Skip to content

Commit

Permalink
api: enforce sympy shifts
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Oct 20, 2024
1 parent 7795225 commit ad7c55b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
10 changes: 5 additions & 5 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from itertools import product

import numpy as np
from sympy import S, finite_diff_weights, cacheit, sympify, Function
from sympy import S, finite_diff_weights, cacheit, sympify, Function, Rational

from devito.tools import Tag, as_tuple
from devito.types.dimension import StencilDimension
Expand Down Expand Up @@ -308,13 +308,13 @@ def make_shift_x0(shift, ndim):
"""
if shift is None:
return lambda s, d, i, j: None
elif isinstance(shift, float):
return lambda s, d, i, j: d + s * d.spacing
elif sympify(shift).is_Number:
return lambda s, d, i, j: d + Rational(s) * d.spacing
elif type(shift) is tuple and np.shape(shift) == ndim:
if len(ndim) == 1:
return lambda s, d, i, j: d + s[j] * d.spacing
return lambda s, d, i, j: d + Rational(s[j]) * d.spacing
elif len(ndim) == 2:
return lambda s, d, i, j: d + s[i][j] * d.spacing
return lambda s, d, i, j: d + Rational(s[i][j]) * d.spacing
else:
raise ValueError("ndim length must be equal to 1 or 2")
raise ValueError("shift parameter must be one of the following options: "
Expand Down
13 changes: 13 additions & 0 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from mpmath.libmp import prec_to_dps, to_str
from packaging.version import Version

from sympy.codegen.ast import float32, float64
from sympy.logic.boolalg import BooleanFunction
from sympy.printing.precedence import PRECEDENCE_VALUES, precedence
from sympy.printing.c import C99CodePrinter
Expand All @@ -18,6 +20,9 @@
__all__ = ['ccode']


_type_mapper = {np.float32: float32, np.float64: float64}


class CodePrinter(C99CodePrinter):

"""
Expand Down Expand Up @@ -179,12 +184,20 @@ def _print_Add(self, expr, order=None):

def _print_Float(self, expr):
"""Print a Float in C-like scientific notation."""
try:
# Make sure the float is in the correct format
expr = _type_mapper[self.dtype].cast_nocheck(expr)
rv = str(expr)
except KeyError:
pass

prec = expr._prec

if prec < 5:
dps = 0
else:
dps = prec_to_dps(expr._prec)

if self._settings["full_prec"] is True:
strip = False
elif self._settings["full_prec"] is False:
Expand Down
5 changes: 3 additions & 2 deletions tests/test_tensors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import sympy
from sympy import Rational

import pytest

Expand Down Expand Up @@ -383,8 +384,8 @@ def test_shifted_lap_of_tensor(shift, ndim):
for j in range(ndim):
ref = 0
for i, d in enumerate(v.space_dimensions):
x0 = (None if shift is None else d + shift[i][j] * d.spacing if
type(shift) is tuple else d + shift * d.spacing)
x0 = (None if shift is None else d + Rational(shift[i][j]) * d.spacing if
type(shift) is tuple else d + Rational(shift) * d.spacing)
ref += getattr(v[j, i], 'd%s2' % d.name)(x0=x0, fd_order=order)
assert df[j] == ref

Expand Down

0 comments on commit ad7c55b

Please sign in to comment.