Merge branch 'GS_WP' into neuralpint
tlunet committed Dec 13, 2024
2 parents 9d3ef1d + 28d6715 commit 3b49d5f
Showing 57 changed files with 4,387 additions and 924 deletions.
3 changes: 2 additions & 1 deletion pySDC/core/controller.py
@@ -7,6 +7,7 @@
 from pySDC.helpers.pysdc_helper import FrozenClass
 from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
 from pySDC.implementations.hooks.default_hook import DefaultHooks
+from pySDC.implementations.hooks.log_timings import CPUTimings


# short helper class to add params as attributes
@@ -43,7 +44,7 @@ def __init__(self, controller_params, description, useMPI=None):
 
         # check if we have a hook on this list. If not, use default class.
         self.__hooks = []
-        hook_classes = [DefaultHooks]
+        hook_classes = [DefaultHooks, CPUTimings]
         user_hooks = controller_params.get('hook_class', [])
         hook_classes += user_hooks if type(user_hooks) == list else [user_hooks]
         [self.add_hook(hook) for hook in hook_classes]
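
With this change, CPU timing hooks are registered by default alongside DefaultHooks, and anything passed via controller_params['hook_class'] (a single class or a list) is appended after them. A minimal sketch of registering a user hook — the LogStepSize class and its stats entry are invented for illustration, assuming the pySDC.core.hooks.Hooks base class and its add_to_stats API:

    from pySDC.core.hooks import Hooks

    class LogStepSize(Hooks):
        """Hypothetical user hook that records the step size after each step."""

        def post_step(self, step, level_number):
            super().post_step(step, level_number)
            L = step.levels[level_number]
            # record the step size in the stats, keyed by time and process
            self.add_to_stats(
                process=step.status.slot,
                time=L.time,
                level=L.level_index,
                iter=step.status.iter,
                sweep=L.status.sweep,
                type='dt',
                value=L.params.dt,
            )

    controller_params = {'hook_class': LogStepSize}  # a list of hooks works too
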
13 changes: 13 additions & 0 deletions pySDC/core/convergence_controller.py
@@ -292,6 +292,19 @@ def post_step_processing(self, controller, S, **kwargs):
         """
         pass
 
+    def post_run_processing(self, controller, S, **kwargs):
+        """
+        Do whatever you want to after the run here.
+
+        Args:
+            controller (pySDC.Controller): The controller
+            S (pySDC.Step): The current step
+
+        Returns:
+            None
+        """
+        pass
+
     def prepare_next_block(self, controller, S, size, time, Tend, **kwargs):
         """
         Prepare stuff like spreading step sizes or whatever.
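
The new post_run_processing hook completes the set of processing entry points (alongside post_step_processing and prepare_next_block) and is called once per step object after the run, as wired into both controllers below. A hedged sketch of a convergence controller using it — the class and log message are made up for the example:

    from pySDC.core.convergence_controller import ConvergenceController

    class LogRunEnd(ConvergenceController):
        """Hypothetical controller that reports the final time once the run is over."""

        def post_run_processing(self, controller, S, **kwargs):
            L = S.levels[0]
            controller.logger.info(f'run finished, last step reached t = {L.time + L.dt}')
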
26 changes: 25 additions & 1 deletion pySDC/helpers/NCCL_communicator.py
@@ -27,7 +27,7 @@ def __getattr__(self, name):
         Args:
             Name (str): Name of the requested attribute
         """
-        if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']:
+        if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split', 'Create_cart', 'Is_inter', 'Get_topology']:
             cp.cuda.get_current_stream().synchronize()
 
         return getattr(self.commMPI, name)
@@ -71,6 +71,26 @@ def get_op(self, MPI_op):
         else:
             raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')
 
+    def reduce(self, sendobj, op=MPI.SUM, root=0):
+        sync = False
+        if hasattr(sendobj, 'data'):
+            if hasattr(sendobj.data, 'ptr'):
+                sync = True
+        if sync:
+            cp.cuda.Device().synchronize()
+
+        return self.commMPI.reduce(sendobj, op=op, root=root)
+
+    def allreduce(self, sendobj, op=MPI.SUM):
+        sync = False
+        if hasattr(sendobj, 'data'):
+            if hasattr(sendobj.data, 'ptr'):
+                sync = True
+        if sync:
+            cp.cuda.Device().synchronize()
+
+        return self.commMPI.allreduce(sendobj, op=op)
+
     def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
         if not hasattr(sendbuf.data, 'ptr'):
             return self.commMPI.Reduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op, root=root)
@@ -113,3 +133,7 @@ def Bcast(self, buf, root=0):
         stream = cp.cuda.get_current_stream()
 
         self.commNCCL.bcast(buff=buf.data.ptr, count=count, datatype=dtype, root=root, stream=stream.ptr)
+
+    def Barrier(self):
+        cp.cuda.get_current_stream().synchronize()
+        self.commMPI.Barrier()
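
All three additions follow the same rule: before a collective that runs through host-side mpi4py (rather than NCCL) touches data that may live on the GPU, the current CuPy stream or device is synchronized so no kernel is still writing the buffer. A reduced sketch of that pattern, assuming a CuPy-backed object exposing data.ptr as the pySDC datatypes do; the helper name is hypothetical:

    import cupy as cp

    def safe_host_reduce(comm, sendobj):
        # synchronize only if the payload looks like device memory
        if hasattr(getattr(sendobj, 'data', None), 'ptr'):
            cp.cuda.Device().synchronize()
        return comm.reduce(sendobj)  # host-side mpi4py object reduce
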
2 changes: 2 additions & 0 deletions pySDC/helpers/plot_helper.py
@@ -42,6 +42,7 @@ def figsize_by_journal(journal, scale, ratio): # pragma: no cover
     textwidths = {
         'JSC_beamer': 426.79135,
         'Springer_Numerical_Algorithms': 338.58778,
+        'Springer_proceedings': 347.12354,
         'JSC_thesis': 434.26027,
         'TUHH_thesis': 426.79135,
     }
@@ -50,6 +51,7 @@
         'JSC_beamer': 214.43411,
         'JSC_thesis': 635.5,
         'TUHH_thesis': 631.65118,
+        'Springer_proceedings': 549.13828,
     }
     assert (
         journal in textwidths.keys()
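
Usage is unchanged; the new journal key simply becomes available, e.g.:

    from pySDC.helpers.plot_helper import figsize_by_journal

    # full text width of a Springer proceedings page, with a chosen height ratio
    figsize = figsize_by_journal('Springer_proceedings', scale=1.0, ratio=0.62)
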
23 changes: 17 additions & 6 deletions pySDC/helpers/spectral_helper.py
@@ -882,6 +882,7 @@ def __init__(self, comm=None, useGPU=False, debug=False):
         self.BCs = None
 
         self.fft_cache = {}
+        self.fft_dealias_shape_cache = {}
 
     @property
     def u_init(self):
@@ -1470,8 +1471,13 @@ def _transform_dct(self, u, axes, padding=None, **kwargs):
 
         if padding is not None:
             shape = list(v.shape)
-            if self.comm:
-                shape[0] = self.comm.allreduce(v.shape[0])
+            if ('forward', *padding) in self.fft_dealias_shape_cache.keys():
+                shape[0] = self.fft_dealias_shape_cache[('forward', *padding)]
+            elif self.comm:
+                send_buf = np.array(v.shape[0])
+                recv_buf = np.array(v.shape[0])
+                self.comm.Allreduce(send_buf, recv_buf)
+                shape[0] = int(recv_buf)
             fft = self.get_fft(axes, 'forward', shape=shape)
         else:
             fft = self.get_fft(axes, 'forward', **kwargs)
@@ -1642,8 +1648,13 @@ def _transform_idct(self, u, axes, padding=None, **kwargs):
         if padding is not None:
             if padding[axis] != 1:
                 shape = list(v.shape)
-                if self.comm:
-                    shape[0] = self.comm.allreduce(v.shape[0])
+                if ('backward', *padding) in self.fft_dealias_shape_cache.keys():
+                    shape[0] = self.fft_dealias_shape_cache[('backward', *padding)]
+                elif self.comm:
+                    send_buf = np.array(v.shape[0])
+                    recv_buf = np.array(v.shape[0])
+                    self.comm.Allreduce(send_buf, recv_buf)
+                    shape[0] = int(recv_buf)
                 ifft = self.get_fft(axes, 'backward', shape=shape)
             else:
                 ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
@@ -1748,8 +1759,6 @@ def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
         if self.comm.size == 1:
             return u.copy()
 
-        fft = self.get_fft(**kwargs) if fft is None else fft
-
         global_fft = self.get_fft(**kwargs)
         axisA = [me.axisA for me in global_fft.transfer]
         axisB = [me.axisB for me in global_fft.transfer]
@@ -1787,6 +1796,8 @@ def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
         else:  # go the potentially slower route of not reusing transfer classes
             from mpi4py_fft import newDistArray
 
+            fft = self.get_fft(**kwargs) if fft is None else fft
+
             _in = newDistArray(fft, forward).redistribute(axis_in)
             _in[...] = u

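
Two things change here: the global first-axis size of the padded (dealiased) transform is cached per (direction, *padding) key instead of being recomputed every call, and the fallback communication switches from the object-based allreduce to the buffer-based Allreduce, which a GPU-aware communicator like the NCCL wrapper above can serve without a host round-trip. A small standalone sketch of the object- vs buffer-based distinction, using plain mpi4py:

    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    local = comm.rank + 1

    total_obj = comm.allreduce(local)  # pickles Python objects on the host

    send_buf = np.array(local)
    recv_buf = np.empty_like(send_buf)
    comm.Allreduce(send_buf, recv_buf)  # operates on array memory directly
    assert int(recv_buf) == total_obj
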
3 changes: 3 additions & 0 deletions pySDC/implementations/controller_classes/controller_MPI.py
@@ -160,6 +160,9 @@ def run(self, u0, t0, Tend):
         for hook in self.hooks:
             hook.post_run(step=self.S, level_number=0)
 
+        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
+            C.post_run_processing(self, self.S, comm=self.comm)
+
         comm_active.Free()
 
         return uend, self.return_stats()
4 changes: 4 additions & 0 deletions pySDC/implementations/controller_classes/controller_nonMPI.py
@@ -171,6 +171,10 @@ def run(self, u0, t0, Tend):
         for hook in self.hooks:
             hook.post_run(step=S, level_number=0)
 
+        for S in self.MS:
+            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
+                C.post_run_processing(self, S, MS=MS_active)
+
         return uend, self.return_stats()
 
     def restart_block(self, active_slots, time, u0):
@@ -229,7 +229,16 @@ def determine_restart(self, controller, S, **kwargs):
         if self.get_convergence(controller, S, **kwargs):
             self.res_last_iter = np.inf
 
-            if self.params.restart_at_maxiter and S.levels[0].status.residual > S.levels[0].params.restol:
+            L = S.levels[0]
+            e_tol_converged = (
+                L.status.increment < L.params.e_tol if (L.params.get('e_tol') and L.status.get('increment')) else False
+            )
+
+            if (
+                self.params.restart_at_maxiter
+                and S.levels[0].status.residual > S.levels[0].params.restol
+                and not e_tol_converged
+            ):
                 self.trigger_restart_upon_nonconvergence(S)
         elif self.get_local_error_estimate(controller, S, **kwargs) > self.params.e_tol:
             S.status.restart = True
@@ -75,9 +75,7 @@ def check_convergence(S, self=None):
         iter_converged = S.status.iter >= S.params.maxiter
         res_converged = L.status.residual <= L.params.restol
         e_tol_converged = (
-            L.status.error_embedded_estimate < L.params.e_tol
-            if (L.params.get('e_tol') and L.status.get('error_embedded_estimate'))
-            else False
+            L.status.increment < L.params.e_tol if (L.params.get('e_tol') and L.status.get('increment')) else False
         )
         converged = (
             iter_converged or res_converged or e_tol_converged or S.status.force_done
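
Both of the preceding hunks key the e_tol check on L.status.increment instead of the embedded error estimate; per the estimator change below, the increment is currently just a copy of that estimate, so behavior is preserved while the two quantities are decoupled. The guarded predicate, restated as a standalone sketch:

    def e_tol_converged(L):
        # only test if an e_tol is configured and an increment has been recorded
        if L.params.get('e_tol') and L.status.get('increment'):
            return L.status.increment < L.params.e_tol
        return False
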
@@ -93,10 +93,10 @@ def estimate_embedded_error_serial(self, L):
             dtype_u: The embedded error estimate
         """
         if self.params.sweeper_type == "RK":
-            # lower order solution is stored in the second to last entry of L.u
-            return abs(L.u[-2] - L.u[-1])
+            L.sweep.compute_end_point()
+            return abs(L.uend - L.sweep.u_secondary)
         elif self.params.sweeper_type == "SDC":
-            # order rises by one between sweeps, making this so ridiculously easy
+            # order rises by one between sweeps
             return abs(L.uold[-1] - L.u[-1])
         elif self.params.sweeper_type == 'MPI':
             comm = L.sweep.comm
@@ -109,12 +109,13 @@ def estimate_embedded_error_serial(self, L):
 
     def setup_status_variables(self, controller, **kwargs):
        """
-        Add the embedded error variable to the error function.
+        Add the embedded error to the level status
 
         Args:
             controller (pySDC.Controller): The controller
         """
         self.add_status_variable_to_level('error_embedded_estimate')
+        self.add_status_variable_to_level('increment')

def post_iteration_processing(self, controller, S, **kwargs):
"""
@@ -134,6 +135,7 @@ def post_iteration_processing(self, controller, S, **kwargs):
         if S.status.iter > 0 or self.params.sweeper_type == "RK":
             for L in S.levels:
                 L.status.error_embedded_estimate = max([self.estimate_embedded_error_serial(L), np.finfo(float).eps])
+                L.status.increment = L.status.error_embedded_estimate * 1
                 self.debug(f'L.status.error_embedded_estimate={L.status.error_embedded_estimate:.5e}', S)

return None
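
For the RK branch, the lower-order solution now comes from L.sweep.u_secondary after recomputing the end point, instead of being read out of L.u[-2]. The underlying idea of any embedded estimate stays the same: take the difference of two solutions of neighboring order and floor it at machine precision, as post_iteration_processing does above. A sketch with plain NumPy arrays, using the max norm purely for illustration:

    import numpy as np

    def embedded_error(u_high, u_low):
        # difference of solutions of neighboring order, floored at machine eps
        return max(float(np.max(np.abs(u_high - u_low))), np.finfo(float).eps)
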
@@ -150,7 +150,7 @@ def post_iteration_processing(self, controller, S, **kwargs):
             if self.comm:
                 buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0)
                 self.comm.Bcast(buf, root=rank)
-                L.status.error_embedded_estimate = buf
+                L.status.error_embedded_estimate = float(buf)
             else:
                 L.status.error_embedded_estimate = abs(u_inter - high_order_sol)

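
The float(buf) cast matters because Bcast fills a 0-d NumPy array in place, and the status entry should be a plain scalar on every rank. A minimal sketch with mpi4py:

    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    buf = np.array(0.0)
    if comm.rank == 0:
        buf[...] = 1.23e-8  # rank 0 owns the estimate
    comm.Bcast(buf, root=0)
    estimate = float(buf)  # plain float, not a 0-d ndarray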