Skip to content

Commit

Permalink
Starting on ML tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
brownbaerchen committed Jan 24, 2025
1 parent 0d6f628 commit a5d30dc
Showing 1 changed file with 60 additions and 4 deletions.
64 changes: 60 additions & 4 deletions pySDC/tutorial/step_7/E_pySDC_with_Firedrake.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def setup(useMPIsweeper):
problem_params['nu'] = 0.1
problem_params['n'] = 128
problem_params['c'] = 1.0
problem_params['order'] = [4, 1]
problem_params['comm'] = ensemble.space_comm

controller_params = dict()
Expand All @@ -65,7 +66,56 @@ def setup(useMPIsweeper):
return description, controller_params


def runHeatFiredrake(useMPIsweeper):
def setup_ML():
"""
Helper routine to set up parameters
Returns:
description and controller_params parameter dictionaries
"""
from pySDC.implementations.problem_classes.HeatFiredrake import Heat1DForcedFiredrake
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
from pySDC.implementations.sweeper_classes.imex_1st_order_MPI import imex_1st_order_MPI
from pySDC.implementations.transfer_classes.TransferFiredrakeMesh import MeshToMeshFiredrake
from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun
from pySDC.implementations.hooks.log_work import LogWork
from pySDC.helpers.firedrake_ensemble_communicator import FiredrakeEnsembleCommunicator

level_params = dict()
level_params['restol'] = 5e-10
level_params['dt'] = 0.2

step_params = dict()
step_params['maxiter'] = 20

sweeper_params = dict()
sweeper_params['quad_type'] = 'RADAU-RIGHT'
sweeper_params['num_nodes'] = [3, 1]
sweeper_params['QI'] = 'LU'
sweeper_params['QE'] = 'PIC'

problem_params = dict()
problem_params['nu'] = 0.1
problem_params['n'] = [128, 32]
problem_params['c'] = 1.0

controller_params = dict()
controller_params['logger_level'] = 15 if MPI.COMM_WORLD.rank == 0 else 30
controller_params['hook_class'] = [LogGlobalErrorPostRun, LogWork]

description = dict()
description['problem_class'] = Heat1DForcedFiredrake
description['problem_params'] = problem_params
description['sweeper_class'] = imex_1st_order
description['sweeper_params'] = sweeper_params
description['level_params'] = level_params
description['step_params'] = step_params
description['space_transfer_class'] = MeshToMeshFiredrake

return description, controller_params


def runHeatFiredrake(useMPIsweeper=False, ML=False):
"""
Run the example defined by the above parameters
"""
Expand All @@ -75,7 +125,11 @@ def runHeatFiredrake(useMPIsweeper):
Tend = 1.0
t0 = 0.0

description, controller_params = setup(useMPIsweeper)
if ML:
assert not useMPIsweeper, 'MPI parallel diagonal SDC and ML SDC are not compatible at the moment'
description, controller_params = setup_ML()
else:
description, controller_params = setup(useMPIsweeper)

controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

Expand All @@ -98,8 +152,9 @@ def runHeatFiredrake(useMPIsweeper):
tot_solves = np.sum([me[1] for me in work_solves])
tot_rhs = np.sum([me[1] for me in work_rhs])

time_rank = description["sweeper_params"]["comm"].rank if useMPIsweeper else 0
print(
f'Finished with error {error[0][1]:.2e}. Used {tot_iter} SDC iterations, with {tot_solver_setup} solver setups, {tot_solves} solves and {tot_rhs} right hand side evaluations on time task {description["sweeper_params"]["comm"].rank}.'
f'Finished with error {error[0][1]:.2e}. Used {tot_iter} SDC iterations, with {tot_solver_setup} solver setups, {tot_solves} solves and {tot_rhs} right hand side evaluations on time task {time_rank}.'
)

# do tests that we got the same as last time
Expand All @@ -112,4 +167,5 @@ def runHeatFiredrake(useMPIsweeper):


if __name__ == "__main__":
runHeatFiredrake(useMPIsweeper=MPI.COMM_WORLD.size > 1)
# runHeatFiredrake(useMPIsweeper=MPI.COMM_WORLD.size > 1)
runHeatFiredrake(ML=True)

0 comments on commit a5d30dc

Please sign in to comment.