Skip to content

Commit

Permalink
Datatype DAEMesh for DAEs (#384)
Browse files Browse the repository at this point in the history
* Added DAE mesh

* Updated all DAE problems and the SDC-DAE sweeper

* Updated playgrounds with new DAE datatype

* Adapted tests

* Minor changes

* Black.. :o

* Added DAEMesh only to semi-explicit DAEs + update for FI-SDC and ProblemDAE.py

* Black :D

* Removed unnecessary approx_solution hook + replaced by LogSolution hook

* Update WSCC9 problem class

* Removed unnecessary comments

* Removed test_misc.py

* Removed registering of newton_tol from child classes

* Update test_problems.py

* Rename error hook class for logging global error in differential variable(s)

* Added MultiComponentMesh - @brownbaerchen + @tlunet + @pancetta Thank ugit add pySDC/implementations/datatype_classes/MultiComponentMesh.py

* Updated stuff with new version of DAE data type

* (Hopefully) faster test for WSCC9

* Test for DAEMesh

* Renaming

* ..for DAEMesh.py

* Bug fix

* Another bug fix..

* Preparation for PDAE stuff (?)

* Changes + adapted first test for PDAE stuff

* Commented out test_WSCC9_SDC_detection() - too long runtime

* Minor changes for test_DAEMesh.py

* Extended test for DAEMesh - credits for @brownbaerchen

* Test for HookClass_DAE.py

* Update for DAEMesh + tests

* 🎉 - speed up test a bit (at least locally..)

* Forgot to enable other tests again

* Removed if-else-statements for mesh type

* View for unknowns in implSysFlatten
  • Loading branch information
lisawim authored Feb 27, 2024
1 parent 05adc5a commit b059c62
Show file tree
Hide file tree
Showing 18 changed files with 471 additions and 303 deletions.
12 changes: 12 additions & 0 deletions pySDC/projects/DAE/misc/DAEMesh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from pySDC.implementations.datatype_classes.mesh import MultiComponentMesh


class DAEMesh(MultiComponentMesh):
r"""
Datatype for DAE problems. The solution of the problem can be splitted in the differential part
and in an algebraic part.
This data type can be used for the solution of the problem itself as well as for its derivative.
"""

components = ['diff', 'alg']
74 changes: 36 additions & 38 deletions pySDC/projects/DAE/misc/HookClass_DAE.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,82 @@
from pySDC.core.Hooks import hooks


class approx_solution_hook(hooks):
class LogGlobalErrorPostStepDifferentialVariable(hooks):
"""
Hook class to add the approximate solution to the output generated by the sweeper after each time step
Hook class to log the error to the output generated by the sweeper after
each time step.
"""

def __init__(self):
"""
Initialization routine for the custom hook
"""
super(approx_solution_hook, self).__init__()

def post_step(self, step, level_number):
"""
Default routine called after each step
Args:
step: the current step
level_number: the current level number
r"""
Default routine called after each step.
Parameters
----------
step : pySDC.core.Step
Current step.
level_number : pySDC.core.level
Current level number.
"""

super(approx_solution_hook, self).post_step(step, level_number)
super().post_step(step, level_number)

# some abbreviations
L = step.levels[level_number]
P = L.prob

# TODO: is it really necessary to recompute the end point? Hasn't this been done already?
L.sweep.compute_end_point()

# compute and save errors
# Note that the component from which the error is measured is specified here
upde = P.u_exact(step.time + step.dt)
e_global_differential = abs(upde.diff - L.uend.diff)

self.add_to_stats(
process=step.status.slot,
time=L.time + L.dt,
level=L.level_index,
iter=step.status.iter,
sweep=L.status.sweep,
type='approx_solution',
value=L.uend,
type='e_global_differential_post_step',
value=e_global_differential,
)


class error_hook(hooks):
class LogGlobalErrorPostStepAlgebraicVariable(hooks):
"""
Hook class to add the approximate solution to the output generated by the sweeper after each time step
Logs the global error in the algebraic variable
"""

def __init__(self):
"""
Initialization routine for the custom hook
"""
super(error_hook, self).__init__()

def post_step(self, step, level_number):
"""
Default routine called after each step
Args:
step: the current step
level_number: the current level number
r"""
Default routine called after each step.
Parameters
----------
step : pySDC.core.Step
Current step.
level_number : pySDC.core.level
Current level number.
"""

super(error_hook, self).post_step(step, level_number)
super().post_step(step, level_number)

# some abbreviations
L = step.levels[level_number]
P = L.prob

# TODO: is it really necessary to recompute the end point? Hasn't this been done already?
L.sweep.compute_end_point()

# compute and save errors
# Note that the component from which the error is measured is specified here
upde = P.u_exact(step.time + step.dt)
err = abs(upde[0] - L.uend[0])
# err = abs(upde[4] - L.uend[4])
e_global_algebraic = abs(upde.alg - L.uend.alg)

self.add_to_stats(
process=step.status.slot,
time=L.time + L.dt,
level=L.level_index,
iter=step.status.iter,
sweep=L.status.sweep,
type='error_post_step',
value=err,
type='e_global_algebraic_post_step',
value=e_global_algebraic,
)
16 changes: 10 additions & 6 deletions pySDC/projects/DAE/misc/ProblemDAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from scipy.optimize import root

from pySDC.core.Problem import ptype, WorkCounter
from pySDC.implementations.datatype_classes.mesh import mesh
from pySDC.projects.DAE.misc.DAEMesh import DAEMesh


class ptype_dae(ptype):
Expand All @@ -25,8 +25,8 @@ class ptype_dae(ptype):
in work_counters['rhs']
"""

dtype_u = mesh
dtype_f = mesh
dtype_u = DAEMesh
dtype_f = DAEMesh

def __init__(self, nvars, newton_tol):
"""Initialization routine"""
Expand Down Expand Up @@ -54,14 +54,18 @@ def solve_system(self, impl_sys, u0, t):
me : dtype_u
Numerical solution.
"""

me = self.dtype_u(self.init)

def implSysFlatten(unknowns, **kwargs):
sys = impl_sys(unknowns.reshape(me.shape).view(type(u0)), **kwargs)
return sys.flatten()

opt = root(
impl_sys,
implSysFlatten,
u0,
method='hybr',
tol=self.newton_tol,
)
me[:] = opt.x
me[:] = opt.x.reshape(me.shape)
self.work_counters['newton'].niter += opt.nfev
return me
36 changes: 17 additions & 19 deletions pySDC/projects/DAE/problems/DiscontinuousTestDAE.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np

from pySDC.core.Problem import WorkCounter
from pySDC.projects.DAE.misc.ProblemDAE import ptype_dae


Expand Down Expand Up @@ -57,14 +58,12 @@ class DiscontinuousTestDAE(ptype_dae):

def __init__(self, newton_tol=1e-12):
"""Initialization routine"""
nvars = 2
super().__init__(nvars, newton_tol)
self._makeAttributeAndRegister('nvars', localVars=locals(), readOnly=True)
self._makeAttributeAndRegister('newton_tol', localVars=locals())
super().__init__(nvars=2, newton_tol=newton_tol)

self.t_switch_exact = np.arccosh(50)
self.t_switch = None
self.nswitches = 0
self.work_counters['rhs'] = WorkCounter()

def eval_f(self, u, du, t):
r"""
Expand All @@ -85,24 +84,21 @@ def eval_f(self, u, du, t):
The right-hand side of f (contains two components).
"""

y, z = u[0], u[1]
dy = du[0]
y, z = u.diff[0], u.alg[0]
dy = du.diff[0]

t_switch = np.inf if self.t_switch is None else self.t_switch

h = 2 * y - 100
f = self.dtype_f(self.init)

if h >= 0 or t >= t_switch:
f[:] = (
dy,
y**2 - z**2 - 1,
)
f.diff[0] = dy
f.alg[0] = y**2 - z**2 - 1
else:
f[:] = (
dy - z,
y**2 - z**2 - 1,
)
f.diff[0] = dy - z
f.alg[0] = y**2 - z**2 - 1
self.work_counters['rhs']()
return f

def u_exact(self, t, **kwargs):
Expand All @@ -125,9 +121,11 @@ def u_exact(self, t, **kwargs):

me = self.dtype_u(self.init)
if t <= self.t_switch_exact:
me[:] = (np.cosh(t), np.sinh(t))
me.diff[0] = np.cosh(t)
me.alg[0] = np.sinh(t)
else:
me[:] = (np.cosh(self.t_switch_exact), np.sinh(self.t_switch_exact))
me.diff[0] = np.cosh(self.t_switch_exact)
me.alg[0] = np.sinh(self.t_switch_exact)
return me

def get_switching_info(self, u, t):
Expand Down Expand Up @@ -162,14 +160,14 @@ def get_switching_info(self, u, t):
m_guess = -100

for m in range(1, len(u)):
h_prev_node = 2 * u[m - 1][0] - 100
h_curr_node = 2 * u[m][0] - 100
h_prev_node = 2 * u[m - 1].diff[0] - 100
h_curr_node = 2 * u[m].diff[0] - 100
if h_prev_node < 0 and h_curr_node >= 0:
switch_detected = True
m_guess = m - 1
break

state_function = [2 * u[m][0] - 100 for m in range(len(u))]
state_function = [2 * u[m].diff[0] - 100 for m in range(len(u))]
return switch_detected, m_guess, state_function

def count_switches(self):
Expand Down
Loading

0 comments on commit b059c62

Please sign in to comment.