Skip to content

Commit

Permalink
Generic MPI FFT class (#408)
Browse files (browse the repository at this point in the history)
* Added generic MPIFFT problem class

* Fixes

* Generalized to `xp` in preparation for GPUs

* Fixes

* Ported Allen-Cahn to generic MPI FFT implementation
  • Loading branch information
brownbaerchen authored Apr 1, 2024
1 parent 3036351 commit ea1ed48
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 320 deletions.
122 changes: 19 additions & 103 deletions pySDC/implementations/problem_classes/AllenCahn_MPIFFT.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import numpy as np
from mpi4py import MPI
from mpi4py_fft import PFFT

from pySDC.core.Errors import ProblemError
from pySDC.core.Problem import ptype
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh

from pySDC.implementations.problem_classes.generic_MPIFFT_Laplacian import IMEX_Laplacian_MPIFFT
from mpi4py_fft import newDistArray


class allencahn_imex(ptype):
class allencahn_imex(IMEX_Laplacian_MPIFFT):
r"""
Example implementing the :math:`N`-dimensional Allen-Cahn equation with periodic boundary conditions on the spatial domain :math:`[0, L]^N` (:math:`L = 1` by default)
Expand Down Expand Up @@ -64,68 +60,21 @@ class allencahn_imex(ptype):
.. [1] https://mpi4py-fft.readthedocs.io/en/latest/
"""

dtype_u = mesh
dtype_f = imex_mesh

def __init__(
self,
nvars=None,
eps=0.04,
radius=0.25,
spectral=None,
dw=0.0,
L=1.0,
init_type='circle',
comm=MPI.COMM_WORLD,
**kwargs,
):
"""Initialization routine"""

if nvars is None:
nvars = (128, 128)

if not (isinstance(nvars, tuple) and len(nvars) > 1):
raise ProblemError('Need at least two dimensions')

# Creating FFT structure
ndim = len(nvars)
axes = tuple(range(ndim))
self.fft = PFFT(comm, list(nvars), axes=axes, dtype=np.float64, collapse=True)

# get test data to figure out type and dimensions
tmp_u = newDistArray(self.fft, spectral)

# invoke super init, passing the communicator and the local dimensions as init
super().__init__(init=(tmp_u.shape, comm, tmp_u.dtype))
self._makeAttributeAndRegister(
'nvars', 'eps', 'radius', 'spectral', 'dw', 'L', 'init_type', 'comm', localVars=locals(), readOnly=True
)

L = np.array([self.L] * ndim, dtype=float)

# get local mesh
X = np.ogrid[self.fft.local_slice(False)]
N = self.fft.global_shape()
for i in range(len(N)):
X[i] = X[i] * L[i] / N[i]
self.X = [np.broadcast_to(x, self.fft.shape(False)) for x in X]

# get local wavenumbers and Laplace operator
s = self.fft.local_slice()
N = self.fft.global_shape()
k = [np.fft.fftfreq(n, 1.0 / n).astype(int) for n in N[:-1]]
k.append(np.fft.rfftfreq(N[-1], 1.0 / N[-1]).astype(int))
K = [ki[si] for ki, si in zip(k, s)]
Ks = np.meshgrid(*K, indexing='ij', sparse=True)
Lp = 2 * np.pi / L
for i in range(ndim):
Ks[i] = (Ks[i] * Lp[i]).astype(float)
K = [np.broadcast_to(k, self.fft.shape(True)) for k in Ks]
K = np.array(K).astype(float)
self.K2 = np.sum(K * K, 0, dtype=float)

# Need this for diagnostics
self.dx = self.L / nvars[0]
self.dy = self.L / nvars[1]
kwargs['L'] = kwargs.get('L', 1.0)
super().__init__(alpha=1.0, dtype=np.dtype('float'), **kwargs)
self._makeAttributeAndRegister('eps', 'radius', 'dw', 'init_type', localVars=locals(), readOnly=True)

def _eval_explicit_part(self, u, t, f_expl):
f_expl[:] = -2.0 / self.eps**2 * u * (1.0 - u) * (1.0 - 2.0 * u) - 6.0 * self.dw * u * (1.0 - u)
return f_expl

def eval_f(self, u, t):
"""
Expand All @@ -146,56 +95,24 @@ def eval_f(self, u, t):

f = self.dtype_f(self.init)

f.impl[:] = self._eval_Laplacian(u, f.impl)

if self.spectral:
f.impl = -self.K2 * u

if self.eps > 0:
tmp = self.fft.backward(u)
tmpf = -2.0 / self.eps**2 * tmp * (1.0 - tmp) * (1.0 - 2.0 * tmp) - 6.0 * self.dw * tmp * (1.0 - tmp)
f.expl[:] = self.fft.forward(tmpf)
tmp[:] = self._eval_explicit_part(tmp, t, tmp)
f.expl[:] = self.fft.forward(tmp)

else:
u_hat = self.fft.forward(u)
lap_u_hat = -self.K2 * u_hat
f.impl[:] = self.fft.backward(lap_u_hat, f.impl)

if self.eps > 0:
f.expl = -2.0 / self.eps**2 * u * (1.0 - u) * (1.0 - 2.0 * u) - 6.0 * self.dw * u * (1.0 - u)
f.expl[:] = self._eval_explicit_part(u, t, f.expl)

self.work_counters['rhs']()
return f

def solve_system(self, rhs, factor, u0, t):
"""
Simple FFT solver for the diffusion part.
Parameters
----------
rhs : dtype_f
Right-hand side for the linear system.
factor : float
Abbrev. for the node-to-node stepsize (or any other factor required).
u0 : dtype_u
Initial guess for the iterative solver (not used here so far).
t : float
Current time (e.g. for time-dependent BCs).
Returns
-------
me : dtype_u
The solution as mesh.
"""

if self.spectral:
me = rhs / (1.0 + factor * self.K2)

else:
me = self.dtype_u(self.init)
rhs_hat = self.fft.forward(rhs)
rhs_hat /= 1.0 + factor * self.K2
me[:] = self.fft.backward(rhs_hat)

return me

def u_exact(self, t):
r"""
Routine to compute the exact solution at time :math:`t`.
Expand Down Expand Up @@ -289,8 +206,9 @@ def eval_f(self, u, t):

f = self.dtype_f(self.init)

f.impl[:] = self._eval_Laplacian(u, f.impl)

if self.spectral:
f.impl = -self.K2 * u

tmp = newDistArray(self.fft, False)
tmp[:] = self.fft.backward(u, tmp)
Expand Down Expand Up @@ -324,9 +242,6 @@ def eval_f(self, u, t):
f.expl[:] = self.fft.forward(tmpf)

else:
u_hat = self.fft.forward(u)
lap_u_hat = -self.K2 * u_hat
f.impl[:] = self.fft.backward(lap_u_hat, f.impl)

if self.eps > 0:
f.expl = -2.0 / self.eps**2 * u * (1.0 - u) * (1.0 - 2.0 * u)
Expand All @@ -353,4 +268,5 @@ def eval_f(self, u, t):

f.expl -= 6.0 * dw * u * (1.0 - u)

self.work_counters['rhs']()
return f
106 changes: 24 additions & 82 deletions pySDC/implementations/problem_classes/Brusselator.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import numpy as np
from mpi4py import MPI
from mpi4py_fft import PFFT

from pySDC.core.Errors import ProblemError
from pySDC.core.Problem import ptype, WorkCounter
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
from pySDC.implementations.problem_classes.generic_MPIFFT_Laplacian import IMEX_Laplacian_MPIFFT

from mpi4py_fft import newDistArray


class Brusselator(ptype):
class Brusselator(IMEX_Laplacian_MPIFFT):
r"""
Two-dimensional Brusselator from [1]_.
This is a reaction-diffusion equation with non-autonomous source term:
Expand All @@ -27,68 +22,29 @@ class Brusselator(ptype):
.. [1] https://link.springer.com/book/10.1007/978-3-642-05221-7
"""

dtype_u = mesh
dtype_f = imex_mesh

def __init__(self, nvars=None, alpha=0.1, comm=MPI.COMM_WORLD):
def __init__(self, alpha=0.1, **kwargs):
"""Initialization routine"""
nvars = (128,) * 2 if nvars is None else nvars
L = 1.0

if not (isinstance(nvars, tuple) and len(nvars) > 1):
raise ProblemError('Need at least two dimensions')

# Create FFT structure
self.ndim = len(nvars)
axes = tuple(range(self.ndim))
self.fft = PFFT(
comm,
list(nvars),
axes=axes,
dtype=np.float64,
collapse=True,
backend='fftw',
)

# get test data to figure out type and dimensions
tmp_u = newDistArray(self.fft, False)
super().__init__(spectral=False, L=1.0, dtype='d', alpha=alpha, **kwargs)

# prepare the array with two components
shape = (2,) + tmp_u.shape
shape = (2,) + (self.init[0])
self.iU = 0
self.iV = 1
self.init = (shape, self.comm, np.dtype('float'))

def _eval_explicit_part(self, u, t, f_expl):
iU, iV = self.iU, self.iV
x, y = self.X[0], self.X[1]

super().__init__(init=(shape, comm, tmp_u.dtype))
self._makeAttributeAndRegister('nvars', 'alpha', 'L', 'comm', localVars=locals(), readOnly=True)

L = np.array([self.L] * self.ndim, dtype=float)

# get local mesh for distributed FFT
X = np.ogrid[self.fft.local_slice(False)]
N = self.fft.global_shape()
for i in range(len(N)):
X[i] = X[i] * L[i] / N[i]
self.X = [np.broadcast_to(x, self.fft.shape(False)) for x in X]

# get local wavenumbers and Laplace operator
s = self.fft.local_slice()
N = self.fft.global_shape()
k = [np.fft.fftfreq(n, 1.0 / n).astype(int) for n in N[:-1]]
k.append(np.fft.rfftfreq(N[-1], 1.0 / N[-1]).astype(int))
K = [ki[si] for ki, si in zip(k, s)]
Ks = np.meshgrid(*K, indexing='ij', sparse=True)
Lp = 2 * np.pi / L
for i in range(self.ndim):
Ks[i] = (Ks[i] * Lp[i]).astype(float)
K = [np.broadcast_to(k, self.fft.shape(True)) for k in Ks]
K = np.array(K).astype(float)
self.K2 = np.sum(K * K, 0, dtype=float)

# Need this for diagnostics
self.dx = self.L / nvars[0]
self.dy = self.L / nvars[1]

self.work_counters['rhs'] = WorkCounter()
# evaluate time-independent part
f_expl[iU, ...] = 1.0 + u[iU] ** 2 * u[iV] - 4.4 * u[iU]
f_expl[iV, ...] = 3.4 * u[iU] - u[iU] ** 2 * u[iV]

# add time-dependent part
if t >= 1.1:
mask = (x - 0.3) ** 2 + (y - 0.6) ** 2 <= 0.1**2
f_expl[iU][mask] += 5.0
return f_expl

def eval_f(self, u, t):
"""
Expand All @@ -106,25 +62,13 @@ def eval_f(self, u, t):
f : dtype_f
The right-hand side of the problem.
"""
iU, iV = self.iU, self.iV
x, y = self.X[0], self.X[1]

f = self.dtype_f(self.init)

# evaluate Laplacian to be solved implicitly
for i in [self.iU, self.iV]:
u_hat = self.fft.forward(u[i, ...])
lap_u_hat = -self.alpha * self.K2 * u_hat
f.impl[i, ...] = self.fft.backward(lap_u_hat, f.impl[i, ...])
f.impl[i, ...] = self._eval_Laplacian(u[i], f.impl[i])

# evaluate time independent part
f.expl[iU, ...] = 1.0 + u[iU] ** 2 * u[iV] - 4.4 * u[iU]
f.expl[iV, ...] = 3.4 * u[iU] - u[iU] ** 2 * u[iV]

# add time-dependent part
if t >= 1.1:
mask = (x - 0.3) ** 2 + (y - 0.6) ** 2 <= 0.1**2
f.expl[iU][mask] += 5.0
f.expl[:] = self._eval_explicit_part(u, t, f.expl)

self.work_counters['rhs']()

Expand Down Expand Up @@ -153,9 +97,7 @@ def solve_system(self, rhs, factor, u0, t):
me = self.dtype_u(self.init)

for i in [self.iU, self.iV]:
rhs_hat = self.fft.forward(rhs[i, ...])
rhs_hat /= 1.0 + factor * self.K2 * self.alpha
me[i, ...] = self.fft.backward(rhs_hat, me[i, ...])
me[i, ...] = self._invert_Laplacian(me[i], factor, rhs[i])

return me

Expand Down Expand Up @@ -184,8 +126,8 @@ def u_exact(self, t, u_init=None, t_init=None):
me = self.dtype_u(self.init, val=0.0)

if t == 0:
me[iU, ...] = 22.0 * y * (1 - y / self.L) ** (3.0 / 2.0) / self.L
me[iV, ...] = 27.0 * x * (1 - x / self.L) ** (3.0 / 2.0) / self.L
me[iU, ...] = 22.0 * y * (1 - y / self.L[0]) ** (3.0 / 2.0) / self.L[0]
me[iV, ...] = 27.0 * x * (1 - x / self.L[0]) ** (3.0 / 2.0) / self.L[0]
else:

def eval_rhs(t, u):
Expand Down
Loading

0 comments on commit ea1ed48

Please sign in to comment.