diff --git a/pySDC/core/convergence_controller.py b/pySDC/core/convergence_controller.py index bcd2fce0ee..3a60c1128e 100644 --- a/pySDC/core/convergence_controller.py +++ b/pySDC/core/convergence_controller.py @@ -292,6 +292,19 @@ def post_step_processing(self, controller, S, **kwargs): """ pass + def post_run_processing(self, controller, S, **kwargs): + """ + Do whatever you want to after the run here. + + Args: + controller (pySDC.Controller): The controller + S (pySDC.Step): The current step + + Returns: + None + """ + pass + def prepare_next_block(self, controller, S, size, time, Tend, **kwargs): """ Prepare stuff like spreading step sizes or whatever. diff --git a/pySDC/implementations/controller_classes/controller_MPI.py b/pySDC/implementations/controller_classes/controller_MPI.py index 9870593978..829d384b23 100644 --- a/pySDC/implementations/controller_classes/controller_MPI.py +++ b/pySDC/implementations/controller_classes/controller_MPI.py @@ -160,6 +160,9 @@ def run(self, u0, t0, Tend): for hook in self.hooks: hook.post_run(step=self.S, level_number=0) + for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: + C.post_run_processing(self, self.S, comm=self.comm) + comm_active.Free() return uend, self.return_stats() diff --git a/pySDC/implementations/controller_classes/controller_nonMPI.py b/pySDC/implementations/controller_classes/controller_nonMPI.py index 0459a702a9..924a6a705f 100644 --- a/pySDC/implementations/controller_classes/controller_nonMPI.py +++ b/pySDC/implementations/controller_classes/controller_nonMPI.py @@ -171,6 +171,10 @@ def run(self, u0, t0, Tend): for hook in self.hooks: hook.post_run(step=S, level_number=0) + for S in self.MS: + for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: + C.post_run_processing(self, S, MS=MS_active) + return uend, self.return_stats() def restart_block(self, active_slots, time, u0): diff --git a/pySDC/projects/GPU/configs/base_config.py b/pySDC/projects/GPU/configs/base_config.py index cdf89eec00..00755c6921 100644 --- a/pySDC/projects/GPU/configs/base_config.py +++ b/pySDC/projects/GPU/configs/base_config.py @@ -145,15 +145,6 @@ def get_initial_condition(self, P, *args, restart_idx=0, **kwargs): else: return P.u_exact(t=0), 0 - def get_previous_stats(self, P, restart_idx): - if restart_idx == 0: - return {} - else: - hook = self.get_LogToFile() - path = LogStats.get_stats_path(hook, counter_offset=0) - with open(path, 'rb') as file: - return pickle.load(file) - def get_LogToFile(self): return None @@ -161,8 +152,26 @@ def get_LogToFile(self): class LogStats(ConvergenceController): @staticmethod - def get_stats_path(hook, counter_offset=-1): - return f'{hook.path}/{hook.file_name}_{hook.format_index(hook.counter+counter_offset)}-stats.pickle' + def get_stats_path(hook, counter_offset=-1, index=None): + index = hook.counter + counter_offset if index is None else index + return f'{hook.path}/{hook.file_name}_{hook.format_index(index)}-stats.pickle' + + def merge_all_stats(self, controller): + hook = self.params.hook + + stats = {} + for i in range(hook.counter): + with open(self.get_stats_path(hook, index=i), 'rb') as file: + _stats = pickle.load(file) + stats = {**stats, **_stats} + + stats = {**stats, **controller.return_stats()} + return stats + + def reset_stats(self, controller): + for hook in controller.hooks: + hook.reset_stats() + self.logger.debug('Reset stats') def setup(self, controller, params, *args, **kwargs): params['control_order'] = 999 @@ -181,11 +190,20 @@ def post_step_processing(self, controller, S, **kwargs): for _hook in controller.hooks: _hook.post_step(S, 0) - if self.counter < hook.counter: - path = self.get_stats_path(hook) + while self.counter < hook.counter: + path = self.get_stats_path(hook, index=self.counter) stats = controller.return_stats() if hook.logging_condition(S.levels[0]): with open(path, 'wb') as file: pickle.dump(stats, file) self.log(f'Stored stats in {path!r}', S) + self.reset_stats(controller) self.counter = hook.counter + + def post_run_processing(self, controller, *args, **kwargs): + stats = self.merge_all_stats(controller) + + def return_stats(): + return stats + + controller.return_stats = return_stats diff --git a/pySDC/projects/GPU/run_experiment.py b/pySDC/projects/GPU/run_experiment.py index e095e07ad7..35b7bbc64a 100644 --- a/pySDC/projects/GPU/run_experiment.py +++ b/pySDC/projects/GPU/run_experiment.py @@ -46,12 +46,9 @@ def run_experiment(args, config, **kwargs): u0, t0 = config.get_initial_condition(prob, restart_idx=args['restart_idx']) - previous_stats = config.get_previous_stats(prob, restart_idx=args['restart_idx']) - uend, stats = controller.run(u0=u0, t0=t0, Tend=config.Tend) - combined_stats = {**previous_stats, **stats} - combined_stats = filter_stats(combined_stats, comm=config.comm_world) + combined_stats = filter_stats(stats, comm=config.comm_world) if config.comm_world.rank == config.comm_world.size - 1: path = f'data/{config.get_path()}-stats-whole-run.pickle' diff --git a/pySDC/projects/GPU/tests/test_configs.py b/pySDC/projects/GPU/tests/test_configs.py index f21759712d..ac25d813ba 100644 --- a/pySDC/projects/GPU/tests/test_configs.py +++ b/pySDC/projects/GPU/tests/test_configs.py @@ -83,7 +83,14 @@ def logging_condition(L): LogToFile.logging_condition = logging_condition return LogToFile - args = {'procs': [1, 1, 1], 'useGPU': False, 'res': -1, 'logger_level': 15, 'restart_idx': restart_idx} + args = { + 'procs': [1, 1, 1], + 'useGPU': False, + 'res': -1, + 'logger_level': 15, + 'restart_idx': restart_idx, + 'mode': 'run', + } config = VdPConfig(args) run_experiment(args, config)