From ad7c55b42acf99681fecae237319cfbf7e60797b Mon Sep 17 00:00:00 2001 From: mloubout Date: Sat, 19 Oct 2024 22:46:48 -0400 Subject: [PATCH] api: enforce sympy shifts --- devito/finite_differences/tools.py | 10 +++++----- devito/symbolics/printer.py | 13 +++++++++++++ tests/test_tensors.py | 5 +++-- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index 8a879923df..35c065b607 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -2,7 +2,7 @@ from itertools import product import numpy as np -from sympy import S, finite_diff_weights, cacheit, sympify, Function +from sympy import S, finite_diff_weights, cacheit, sympify, Function, Rational from devito.tools import Tag, as_tuple from devito.types.dimension import StencilDimension @@ -308,13 +308,13 @@ def make_shift_x0(shift, ndim): """ if shift is None: return lambda s, d, i, j: None - elif isinstance(shift, float): - return lambda s, d, i, j: d + s * d.spacing + elif sympify(shift).is_Number: + return lambda s, d, i, j: d + Rational(s) * d.spacing elif type(shift) is tuple and np.shape(shift) == ndim: if len(ndim) == 1: - return lambda s, d, i, j: d + s[j] * d.spacing + return lambda s, d, i, j: d + Rational(s[j]) * d.spacing elif len(ndim) == 2: - return lambda s, d, i, j: d + s[i][j] * d.spacing + return lambda s, d, i, j: d + Rational(s[i][j]) * d.spacing else: raise ValueError("ndim length must be equal to 1 or 2") raise ValueError("shift parameter must be one of the following options: " diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 672971cf4f..f60feee443 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -7,6 +7,8 @@ from mpmath.libmp import prec_to_dps, to_str from packaging.version import Version + +from sympy.codegen.ast import float32, float64 from sympy.logic.boolalg import BooleanFunction from sympy.printing.precedence import PRECEDENCE_VALUES, precedence from sympy.printing.c import C99CodePrinter @@ -18,6 +20,9 @@ __all__ = ['ccode'] +_type_mapper = {np.float32: float32, np.float64: float64} + + class CodePrinter(C99CodePrinter): """ @@ -179,12 +184,20 @@ def _print_Add(self, expr, order=None): def _print_Float(self, expr): """Print a Float in C-like scientific notation.""" + try: + # Make sure the float is in the correct format + expr = _type_mapper[self.dtype].cast_nocheck(expr) + rv = str(expr) + except KeyError: + pass + prec = expr._prec if prec < 5: dps = 0 else: dps = prec_to_dps(expr._prec) + if self._settings["full_prec"] is True: strip = False elif self._settings["full_prec"] is False: diff --git a/tests/test_tensors.py b/tests/test_tensors.py index 15e18ababd..48b4e57d53 100644 --- a/tests/test_tensors.py +++ b/tests/test_tensors.py @@ -1,5 +1,6 @@ import numpy as np import sympy +from sympy import Rational import pytest @@ -383,8 +384,8 @@ def test_shifted_lap_of_tensor(shift, ndim): for j in range(ndim): ref = 0 for i, d in enumerate(v.space_dimensions): - x0 = (None if shift is None else d + shift[i][j] * d.spacing if - type(shift) is tuple else d + shift * d.spacing) + x0 = (None if shift is None else d + Rational(shift[i][j]) * d.spacing if + type(shift) is tuple else d + Rational(shift) * d.spacing) ref += getattr(v[j, i], 'd%s2' % d.name)(x0=x0, fd_order=order) assert df[j] == ref