diff --git a/pySDC/implementations/hooks/log_solution.py b/pySDC/implementations/hooks/log_solution.py index f6a1980558..d6c85b8e43 100644 --- a/pySDC/implementations/hooks/log_solution.py +++ b/pySDC/implementations/hooks/log_solution.py @@ -1,4 +1,7 @@ from pySDC.core.Hooks import hooks +import pickle +import os +import numpy as np class LogSolution(hooks): @@ -63,3 +66,79 @@ def post_iteration(self, step, level_number): type='u', value=L.uend, ) + + +class LogToFile(hooks): + r""" + Hook for logging the solution to file after the step using pickle. + + Please configure the hook to your liking by manipulating class attributes. + You must set a custom path to a directory like so: + + ``` + LogToFile.path = '/my/directory/' + ``` + + Keep in mind that the hook will overwrite files without warning! + You can give a custom file name by setting the ``file_name`` class attribute and give a custom way of rendering the + index associated with individual files by giving a different lambda function ``format_index`` class attribute. This + lambda should accept one index and return one string. + + You can also give a custom ``logging_condition`` lambda, accepting the current level if you want to log selectively. + + Importantly, you may need to change ``process_solution``. By default, this will return a numpy view of the solution. + Of course, if you are not using numpy, you need to change this. Again, this is a lambda accepting the level. + + After the fact, you can use the classmethod `get_path` to get the path to a certain data or the `load` function to + directly load the solution at a given index. Just configure the hook like you did when you recorded the data + beforehand. + + Finally, be aware that using this hook with MPI parallel runs may lead to different tasks overwriting files. Make + sure to give a different `file_name` for each task that writes files. + """ + + path = None + file_name = 'solution' + logging_condition = lambda L: True + process_solution = lambda L: {'t': L.time + L.dt, 'u': L.uend.view(np.ndarray)} + format_index = lambda index: f'{index:06d}' + + def __init__(self): + super().__init__() + self.counter = 0 + + if self.path is None: + raise ValueError('Please set a path for logging as the class attribute `LogToFile.path`!') + + if os.path.isfile(self.path): + raise ValueError( + f'{self.path!r} is not a valid path to log to because a file of the same name exists. Please supply a directory' + ) + + if not os.path.isdir(self.path): + os.mkdir(self.path) + + def post_step(self, step, level_number): + if level_number > 0: + return None + + L = step.levels[level_number] + + if type(self).logging_condition(L): + path = self.get_path(self.counter) + data = type(self).process_solution(L) + + with open(path, 'wb') as file: + pickle.dump(data, file) + + self.counter += 1 + + @classmethod + def get_path(cls, index): + return f'{cls.path}/{cls.file_name}_{cls.format_index(index)}.pickle' + + @classmethod + def load(cls, index): + path = cls.get_path(index) + with open(path, 'rb') as file: + return pickle.load(file) diff --git a/pySDC/tests/test_hooks/test_log_to_file.py b/pySDC/tests/test_hooks/test_log_to_file.py new file mode 100644 index 0000000000..0f0d48f0e2 --- /dev/null +++ b/pySDC/tests/test_hooks/test_log_to_file.py @@ -0,0 +1,85 @@ +import pytest + + +def run(hook, Tend=0): + from pySDC.implementations.problem_classes.TestEquation_0D import testequation0d + from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit + from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI + + level_params = {'dt': 1.0e-1} + + sweeper_params = { + 'num_nodes': 1, + 'quad_type': 'GAUSS', + } + + description = { + 'level_params': level_params, + 'sweeper_class': generic_implicit, + 'problem_class': testequation0d, + 'sweeper_params': sweeper_params, + 'problem_params': {}, + 'step_params': {'maxiter': 1}, + } + + controller_params = { + 'hook_class': hook, + 'logger_level': 30, + } + controller = controller_nonMPI(1, controller_params, description) + if Tend > 0: + prob = controller.MS[0].levels[0].prob + u0 = prob.u_exact(0) + + _, stats = controller.run(u0, 0, Tend) + return stats + + +@pytest.mark.base +def test_errors(): + from pySDC.implementations.hooks.log_solution import LogToFile + import os + + with pytest.raises(ValueError): + run(LogToFile) + + LogToFile.path = os.getcwd() + run(LogToFile) + + path = f'{os.getcwd()}/tmp' + LogToFile.path = path + run(LogToFile) + os.path.isdir(path) + + with pytest.raises(ValueError): + LogToFile.path = __file__ + run(LogToFile) + + +@pytest.mark.base +def test_logging(): + from pySDC.implementations.hooks.log_solution import LogToFile, LogSolution + from pySDC.helpers.stats_helper import get_sorted + import os + import pickle + import numpy as np + + path = f'{os.getcwd()}/tmp' + LogToFile.path = path + Tend = 2 + + stats = run([LogToFile, LogSolution], Tend=Tend) + u = get_sorted(stats, type='u') + + u_file = [] + for i in range(len(u)): + data = LogToFile.load(i) + u_file += [(data['t'], data['u'])] + + for us, uf in zip(u, u_file): + assert us[0] == uf[0] + assert np.allclose(us[1], uf[1]) + + +if __name__ == '__main__': + test_logging()