Skip to content

Commit

Permalink
Added hook for storing solution to file after regular time increment (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
brownbaerchen authored Sep 19, 2024
1 parent 2848895 commit 7ded866
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 11 deletions.
56 changes: 48 additions & 8 deletions pySDC/implementations/hooks/log_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ class LogToFile(Hooks):
Keep in mind that the hook will overwrite files without warning!
You can give a custom file name by setting the ``file_name`` class attribute and give a custom way of rendering the
index associated with individual files by giving a different lambda function ``format_index`` class attribute. This
lambda should accept one index and return one string.
index associated with individual files by giving a different function ``format_index`` class attribute. This should
accept one index and return one string.
You can also give a custom ``logging_condition`` function, accepting the current level if you want to log selectively.
Importantly, you may need to change ``process_solution``. By default, this will return a numpy view of the solution.
Of course, if you are not using numpy, you need to change this. Again, this is a lambda accepting the level.
Of course, if you are not using numpy, you need to change this. Again, this is a function accepting the level.
After the fact, you can use the classmethod `get_path` to get the path to a certain data or the `load` function to
directly load the solution at a given index. Just configure the hook like you did when you recorded the data
Expand All @@ -99,6 +99,7 @@ class LogToFile(Hooks):

path = None
file_name = 'solution'
counter = 0

def logging_condition(L):
return True
Expand All @@ -111,7 +112,6 @@ def format_index(index):

def __init__(self):
super().__init__()
self.counter = 0

if self.path is None:
raise ValueError('Please set a path for logging as the class attribute `LogToFile.path`!')
Expand All @@ -124,20 +124,41 @@ def __init__(self):
if not os.path.isdir(self.path):
os.mkdir(self.path)

def post_step(self, step, level_number):
def log_to_file(self, step, level_number, condition, process_solution=None):
if level_number > 0:
return None

L = step.levels[level_number]

if type(self).logging_condition(L):
if condition:
path = self.get_path(self.counter)
data = type(self).process_solution(L)

if process_solution:
data = process_solution(L)
else:
data = type(self).process_solution(L)

with open(path, 'wb') as file:
pickle.dump(data, file)
self.logger.info(f'Stored file {path!r}')

type(self).counter += 1

def post_step(self, step, level_number):
L = step.levels[level_number]
self.log_to_file(step, level_number, type(self).logging_condition(L))

def pre_run(self, step, level_number):
L = step.levels[level_number]
L.uend = L.u[0]

def process_solution(L):
return {
**type(self).process_solution(L),
't': L.time,
}

self.counter += 1
self.log_to_file(step, level_number, True, process_solution=process_solution)

@classmethod
def get_path(cls, index):
Expand All @@ -148,3 +169,22 @@ def load(cls, index):
path = cls.get_path(index)
with open(path, 'rb') as file:
return pickle.load(file)


class LogToFileAfterXs(LogToFile):
r'''
Log to file after certain amount of time has passed instead of after every step
'''

time_increment = 0
t_next_log = 0

def post_step(self, step, level_number):
L = step.levels[level_number]

if self.t_next_log == 0:
self.t_next_log = self.time_increment

if L.time + L.dt >= self.t_next_log and not step.status.restart:
super().post_step(step, level_number)
self.t_next_log = max([L.time + L.dt, self.t_next_log]) + self.time_increment
6 changes: 3 additions & 3 deletions pySDC/tests/test_hooks/test_log_to_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def run(hook, Tend=0):
u0 = prob.u_exact(0)

_, stats = controller.run(u0, 0, Tend)
return stats
return u0, stats


@pytest.mark.base
Expand Down Expand Up @@ -68,8 +68,8 @@ def test_logging():
LogToFile.path = path
Tend = 2

stats = run([LogToFile, LogSolution], Tend=Tend)
u = get_sorted(stats, type='u')
u0, stats = run([LogToFile, LogSolution], Tend=Tend)
u = [(0.0, u0)] + get_sorted(stats, type='u')

u_file = []
for i in range(len(u)):
Expand Down

0 comments on commit 7ded866

Please sign in to comment.