Skip to content

Commit

Permalink
pySDC-build-in LagrangeApproximation class in SwitchEstimator (#…
Browse files Browse the repository at this point in the history
…406)

* SE now uses LagrangeApproximation class + removed Lagrange class in SE

* Removed log message again (not corresponding to PR)
  • Loading branch information
lisawim authored Mar 8, 2024
1 parent b31fd03 commit 82a6b73
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 62 deletions.
16 changes: 15 additions & 1 deletion pySDC/core/Lagrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class LagrangeApproximation(object):
The associated barycentric weights
"""

def __init__(self, points):
def __init__(self, points, fValues=None):
points = np.asarray(points).ravel()

diffs = points[:, None] - points[None, :]
Expand All @@ -110,6 +110,20 @@ def analytic(diffs):
self.points = points
self.weights = weights

# Store function values if provided
if fValues is not None:
fValues = np.asarray(fValues)
if fValues.shape != points.shape:
raise ValueError(f'fValues {fValues.shape} has not the correct shape: {points.shape}')
self.fValues = fValues

def __call__(self, t):
assert self.fValues is not None, "cannot evaluate polynomial without fValues"
t = np.asarray(t)
values = self.getInterpolationMatrix(t.ravel()).dot(self.fValues)
values.shape = t.shape
return values

@property
def n(self):
return self.points.size
Expand Down
64 changes: 3 additions & 61 deletions pySDC/projects/PinTSimE/switch_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pySDC.core.Collocation import CollBase
from pySDC.core.ConvergenceController import ConvergenceController, Status
from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
from pySDC.core.Lagrange import LagrangeApproximation


class SwitchEstimator(ConvergenceController):
Expand Down Expand Up @@ -274,23 +275,8 @@ def get_switch(t_interp, state_function, m_guess):
Time point of found event.
"""

LagrangeInterpolator = LagrangeInterpolation(t_interp, state_function)

def p(t):
"""
Simplifies the call of the interpolant.
Parameters
----------
t : float
Time t at which the interpolant is called.
Returns
-------
p(t) : float
The value of the interpolated function at time t.
"""
return LagrangeInterpolator.eval(t)
LagrangeInterpolation = LagrangeApproximation(points=t_interp, fValues=state_function)
p = lambda t: LagrangeInterpolation.__call__(t)

def fprime(t):
"""
Expand Down Expand Up @@ -385,47 +371,3 @@ def newton(x0, p, fprime, newton_tol, newton_maxiter):
root = x0

return root


class LagrangeInterpolation(object):
def __init__(self, ti, yi):
"""Initialization routine"""
self.ti = np.asarray(ti)
self.yi = np.asarray(yi)
self.n = len(ti)

def get_Lagrange_polynomial(self, t, i):
"""
Computes the basis of the i-th Lagrange polynomial.
Parameters
----------
t : float
Time where the polynomial is computed at.
i : int
Index of the Lagrange polynomial
Returns
-------
product : float
The product of the bases.
"""
product = np.prod([(t - self.ti[k]) / (self.ti[i] - self.ti[k]) for k in range(self.n) if k != i])
return product

def eval(self, t):
"""
Evaluates the Lagrange interpolation at time t.
Parameters
----------
t : float
Time where interpolation is computed.
Returns
-------
p : float
Value of interpolant at time t.
"""
p = np.sum([self.yi[i] * self.get_Lagrange_polynomial(t, i) for i in range(self.n)])
return p

0 comments on commit 82a6b73

Please sign in to comment.