Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tentative solution to replace FrozenClass #244

Merged
merged 21 commits into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pySDC/core/Collocation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import logging

import numpy as np
import scipy.interpolate as intpl

from pySDC.core.Nodes import NodesGenerator
from pySDC.core.Errors import CollocationError
Expand Down
11 changes: 5 additions & 6 deletions pySDC/core/Controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,11 @@ def dump_setup(self, step, controller_params, description):
else:
out += ' %s = %s\n' % (k, v)
out += '--> Problem: %s\n' % L.prob.__class__
for k, v in vars(L.prob.params).items():
if not k.startswith('_'):
if k in description['problem_params']:
out += '--> %s = %s\n' % (k, v)
else:
out += ' %s = %s\n' % (k, v)
for k, v in L.prob.params.items():
if k in description['problem_params']:
out += '--> %s = %s\n' % (k, v)
else:
out += ' %s = %s\n' % (k, v)
out += '--> Data type u: %s\n' % L.prob.dtype_u
out += '--> Data type f: %s\n' % L.prob.dtype_f
out += '--> Sweeper: %s\n' % L.sweep.__class__
Expand Down
24 changes: 11 additions & 13 deletions pySDC/core/Problem.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import logging

from pySDC.helpers.pysdc_helper import FrozenClass

# parent class that register some class attributes in a list of paramters
class RegisterParams(object):
def _register(self, *parNames):
if not hasattr(self, '_parNames'):
self._parNames = []
self._parNames += parNames

# short helper class to add params as attributes
class _Pars(FrozenClass):
def __init__(self, pars):
@property
def params(self):
return {name: getattr(self, name) for name in self._parNames}

for k, v in pars.items():
setattr(self, k, v)

self._freeze()


class ptype(object):
class ptype(RegisterParams):
"""
Prototype class for problems, just defines the attributes essential to get started

Expand All @@ -25,7 +25,7 @@ class ptype(object):
dtype_f: RHS data type
"""

def __init__(self, init, dtype_u, dtype_f, **kwargs):
def __init__(self, init, dtype_u, dtype_f):
"""
Initialization routine.
Add the problem parameters as keyword arguments.
Expand All @@ -35,8 +35,6 @@ def __init__(self, init, dtype_u, dtype_f, **kwargs):
dtype_u: variable data type
dtype_f: RHS data type
"""
self.params = _Pars(kwargs)

# set up logger
self.logger = logging.getLogger('problem')

Expand Down
207 changes: 115 additions & 92 deletions pySDC/implementations/problem_classes/AdvectionEquation_ND_FD.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import scipy.sparse as sp
from scipy.sparse.linalg import gmres, spsolve

from pySDC.core.Errors import ParameterError, ProblemError
from pySDC.core.Errors import ProblemError
from pySDC.core.Problem import ptype
from pySDC.helpers import problem_helper
from pySDC.implementations.datatype_classes.mesh import mesh
Expand Down Expand Up @@ -30,13 +30,13 @@ def __init__(
liniter=10000,
direct_solver=True,
bc='periodic',
ndim=None,
sigma=6e-2,
):
"""
Initialization routine

Args can be set as values or as tuples, which will increase the dimension. Do, however, take care that all
spatial parameters have the same dimension.
Args can be set as values or as tuples, which will increase the dimension.
Do, however, take care that all spatial parameters have the same dimension.

Args:
nvars (int): Spatial resolution, can be tuple
Expand All @@ -48,91 +48,108 @@ def __init__(
liniter (int): Max. iterations for GMRES
direct_solver (bool): Whether to solve directly or use GMRES
bc (str): Boundary conditions
ndim (int): Number of dimensions. Is set automatically if left at None.
"""

# make sure parameters have the correct form
if not (type(nvars) is tuple and type(freq) is tuple) and not (type(nvars) is int and type(freq) is int):
print(nvars, freq)
raise ProblemError('Type of nvars and freq must be both either int or both tuple')
# make sure parameters have the correct types
if not type(nvars) in [int, tuple]:
raise ProblemError('nvars should be either tuple or int')
if not type(freq) in [int, tuple]:
raise ProblemError('freq should be either tuple or int')

if ndim is None:
if type(nvars) is int:
ndim = 1
elif type(nvars) is tuple:
ndim = len(nvars)
# transforms nvars into a tuple
if type(nvars) is int:
nvars = (nvars,)

# automatically determine ndim from nvars
ndim = len(nvars)
if ndim > 3:
raise ProblemError(f'can work with up to three dimensions, got {ndim}')

if type(freq) is tuple:
for f in freq:
if f % 2 != 0 and bc == 'periodic':
raise ProblemError('need even number of frequencies due to periodic BCs')
else:
if freq % 2 != 0 and freq != -1 and bc == 'periodic':
tlunet marked this conversation as resolved.
Show resolved Hide resolved
# eventually extend freq to other dimension
if type(freq) is int:
freq = (freq,) * ndim
if len(freq) != ndim:
raise ProblemError(f'len(freq)={len(freq)}, different to ndim={ndim}')

# check values for freq and nvars
for f in freq:
if ndim == 1 and f == -1:
# use Gaussian initial solution in 1D
bc == 'periodic'
break
if f % 2 != 0 and bc == 'periodic':
raise ProblemError('need even number of frequencies due to periodic BCs')

if type(nvars) is tuple:
for nvar in nvars:
if nvar % 2 != 0 and bc == 'periodic':
raise ProblemError('the setup requires nvars = 2^p per dimension')
if (nvar + 1) % 2 != 0 and bc == 'dirichlet-zero':
raise ProblemError('setup requires nvars = 2^p - 1')
if nvars[1:] != nvars[:-1]:
raise ProblemError('need a square domain, got %s' % nvars)
else:
if nvars % 2 != 0 and bc == 'periodic':
for nvar in nvars:
if nvar % 2 != 0 and bc == 'periodic':
raise ProblemError('the setup requires nvars = 2^p per dimension')
if (nvars + 1) % 2 != 0 and bc == 'dirichlet-zero':
if (nvar + 1) % 2 != 0 and bc == 'dirichlet-zero':
raise ProblemError('setup requires nvars = 2^p - 1')
if ndim > 1 and nvars[1:] != nvars[:-1]:
raise ProblemError('need a square domain, got %s' % nvars)

# invoke super init, passing number of dofs, dtype_u and dtype_f
super(advectionNd, self).__init__(
super().__init__(
init=(nvars, None, np.dtype('float64')),
dtype_u=mesh,
dtype_f=mesh,
nvars=nvars,
c=c,
freq=freq,
stencil_type=stencil_type,
order=order,
lintol=lintol,
liniter=liniter,
direct_solver=direct_solver,
bc=bc,
ndim=ndim,
)

if self.params.ndim == 1:
if type(self.params.nvars) is not tuple:
self.params.nvars = (self.params.nvars,)
if type(self.params.freq) is not tuple:
self.params.freq = (self.params.freq,)

# compute dx (equal in both dimensions) and get discretization matrix A
if self.params.bc == 'periodic':
self.dx = 1.0 / self.params.nvars[0]
xvalues = np.array([i * self.dx for i in range(self.params.nvars[0])])
elif self.params.bc == 'dirichlet-zero':
self.dx = 1.0 / (self.params.nvars[0] + 1)
xvalues = np.array([(i + 1) * self.dx for i in range(self.params.nvars[0])])
if bc == 'periodic':
xvalues = np.linspace(0, 1, num=nvars[0], endpoint=False)
elif bc == 'dirichlet-zero':
xvalues = np.linspace(0, 1, num=nvars[0] + 2)[1:-1]
else:
raise ProblemError(f'Boundary conditions {self.params.bc} not implemented.')
dx = xvalues[1] - xvalues[0]

self.A = problem_helper.get_finite_difference_matrix(
derivative=1,
order=self.params.order,
stencil_type=self.params.stencil_type,
dx=self.dx,
size=self.params.nvars[0],
dim=self.params.ndim,
bc=self.params.bc,
order=order,
stencil_type=stencil_type,
dx=dx,
size=nvars[0],
dim=ndim,
bc=bc,
)
self.A *= -self.params.c
self.A *= -c

self.xvalues = xvalues
self.Id = sp.eye(np.prod(nvars), format='csc')

# store relevant attributes
self.freq, self.sigma = freq, sigma
self.lintol, self.liniter, self.direct_solve = lintol, liniter, direct_solver

# read-only attributes
self._readOnly = [ndim, c, stencil_type, order]
tlunet marked this conversation as resolved.
Show resolved Hide resolved

# register parameters
self._register('nvars', 'c', 'freq', 'stencil_type', 'order', 'lintol', 'liniter', 'direct_solver', 'bc')
tlunet marked this conversation as resolved.
Show resolved Hide resolved

self.xv = np.meshgrid(*[xvalues for _ in range(self.params.ndim)])
self.Id = sp.eye(np.prod(self.params.nvars), format='csc')
@property
def ndim(self):
return self._readOnly[0]

@property
def c(self):
return self._readOnly[1]

@property
def stencil_type(self):
return self._readOnly[2]

@property
def order(self):
return self._readOnly[3]

@property
def nvars(self):
return (self.xvalues.size,) * self.ndim

@property
def bc(self):
return 'periodic' if self.xvalue[0] == 0 else 'dirichlet-zero'

def eval_f(self, u, t):
"""
Expand All @@ -147,7 +164,7 @@ def eval_f(self, u, t):
"""

f = self.dtype_f(self.init)
f[:] = self.A.dot(u.flatten()).reshape(self.params.nvars)
f[:] = self.A.dot(u.flatten()).reshape(self.nvars)
return f

def solve_system(self, rhs, factor, u0, t):
Expand All @@ -163,20 +180,23 @@ def solve_system(self, rhs, factor, u0, t):
Returns:
dtype_u: solution as mesh
"""

direct_solver, Id, A, nvars, lintol, liniter = (
self.direct_solver,
self.Id,
self.A,
self.nvars,
self.lintol,
self.liniter,
)
me = self.dtype_u(self.init)

if self.params.direct_solver:
me[:] = spsolve(self.Id - factor * self.A, rhs.flatten()).reshape(self.params.nvars)
if direct_solver:
me[:] = spsolve(Id - factor * A, rhs.flatten()).reshape(nvars)
else:
me[:] = gmres(
self.Id - factor * self.A,
rhs.flatten(),
x0=u0.flatten(),
tol=self.params.lintol,
maxiter=self.params.liniter,
atol=0,
)[0].reshape(self.params.nvars)
me[:] = gmres(Id - factor * A, rhs.flatten(), x0=u0.flatten(), tol=lintol, maxiter=liniter, atol=0,)[
0
].reshape(nvars)

return me

def u_exact(self, t, **kwargs):
Expand All @@ -187,24 +207,27 @@ def u_exact(self, t, **kwargs):
t (float): current time

Returns:
dtype_u: exact solution
me: exact solution
tlunet marked this conversation as resolved.
Show resolved Hide resolved
"""

# Initialize pointers and variables
ndim, freq, x, c, sigma = self.ndim, self.freq, self.xvalues, self.c, self.sigma
me = self.dtype_u(self.init)
if self.params.ndim == 1:
if self.params.freq[0] >= 0:
me[:] = np.sin(np.pi * self.params.freq[0] * (self.xv[0] - self.params.c * t))
elif self.params.freq[0] == -1:
me[:] = np.exp(-0.5 * (((self.xv[0] - (self.params.c * t)) % 1.0 - 0.5) / self.params.sigma) ** 2)

elif self.params.ndim == 2:
me[:] = np.sin(np.pi * self.params.freq[0] * (self.xv[0] - self.params.c * t)) * np.sin(
np.pi * self.params.freq[1] * (self.xv[1] - self.params.c * t)
)
elif self.params.ndim == 3:

if ndim == 1:
if freq[0] >= 0:
me[:] = np.sin(np.pi * freq[0] * (x - c * t))
elif freq[0] == -1:
# Gaussian initial solution
me[:] = np.exp(-0.5 * (((x - (c * t)) % 1.0 - 0.5) / sigma) ** 2)

elif ndim == 2:
me[:] = np.sin(np.pi * freq[0] * (x[None, :] - c * t)) * np.sin(np.pi * freq[1] * (x[:, None] - c * t))

elif ndim == 3:
me[:] = (
np.sin(np.pi * self.params.freq[0] * (self.xv[0] - self.params.c * t))
* np.sin(np.pi * self.params.freq[1] * (self.xv[1] - self.params.c * t))
* np.sin(np.pi * self.params.freq[2] * (self.xv[2] - self.params.c * t))
np.sin(np.pi * freq[0] * (x[None, :, None] - c * t))
* np.sin(np.pi * freq[1] * (x[:, None, None] - c * t))
* np.sin(np.pi * freq[2] * (x[None, None, :] - c * t))
)

return me