Skip to content

Commit

Permalink
More meta data in statistics (#351)
Browse files Browse the repository at this point in the history
* Custom metadata in statistics

* Fixes
  • Loading branch information
brownbaerchen authored Sep 2, 2023
1 parent 1a51834 commit 314223c
Show file tree
Hide file tree
Showing 12 changed files with 238 additions and 67 deletions.
69 changes: 35 additions & 34 deletions pySDC/core/Hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@
from collections import namedtuple


Entry = namedtuple('Entry', ['process', 'time', 'level', 'iter', 'sweep', 'type', 'num_restarts'])
# Metadata fields that identify a statistics entry, each paired with its default value.
# Every field defaults to None; the namedtuple below takes its field names (and their order) from this dict.
meta_data = dict(
    process=None,
    process_sweeper=None,
    time=None,
    level=None,
    iter=None,
    sweep=None,
    type=None,
    num_restarts=None,
)
# Key type used for the statistics dict: one field per metadata name, in the order declared above.
Entry = namedtuple('Entry', tuple(meta_data))


# noinspection PyUnusedLocal,PyShadowingBuiltins,PyShadowingNames
Expand All @@ -18,9 +29,12 @@ class hooks(object):
logger: logger instance for output
__num_restarts (int): number of restarts of the current step
__stats (dict): dictionary for gathering the statistics of a run
__entry (namedtuple): statistics entry containing all information to identify the value
entry (namedtuple): statistics entry containing all information to identify the value
"""

entry = Entry
meta_data = meta_data

def __init__(self):
"""
Initialization routine
Expand All @@ -31,52 +45,39 @@ def __init__(self):

# create statistics and entry elements
self.__stats = {}
self.__entry = Entry

def add_to_stats(self, process, time, level, iter, sweep, type, value):
def add_to_stats(self, value, **kwargs):
"""
Routine to add data to the statistics dict
Routine to add data to the statistics dict. Please supply the metadata as keyword arguments in accordance with
the entry class.
Args:
process: the current process recording this data
time (float): the current simulation time
level (int): the current level index
iter (int): the current iteration count
sweep (int): the current sweep count
type (str): string to describe the type of value
value: the actual data
"""
# create named tuple for the key and add to dict
self.__stats[
self.__entry(
process=process,
time=time,
level=level,
iter=iter,
sweep=sweep,
type=type,
num_restarts=self.__num_restarts,
)
] = value

def increment_stats(self, process, time, level, iter, sweep, type, value, initialize=None):
meta = {
**self.meta_data,
**kwargs,
'num_restarts': self.__num_restarts,
}
self.__stats[self.entry(**meta)] = value

def increment_stats(self, value, initialize=None, **kwargs):
"""
Routine to increment data to the statistics dict. If the data is not yet created, it will be initialized to
initialize if applicable and to value otherwise
initialize if applicable and to value otherwise. Please supply metadata as keyword arguments in accordance with
the entry class.
Args:
process: the current process recording this data
time (float): the current simulation time
level (int): the current level index
iter (int): the current iteration count
sweep (int): the current sweep count
type (str): string to describe the type of value
value: the actual data
initialize: if supplied and data does not exist already, this will be used over value
"""
key = self.__entry(
process=process, time=time, level=level, iter=iter, sweep=sweep, type=type, num_restarts=self.__num_restarts
)
# Merge metadata in order of increasing priority: class-level defaults, caller-supplied keyword arguments, and
# finally the current restart count. The defaults are read via ``self.meta_data`` (as ``add_to_stats`` does)
# rather than from the module-level dict, so that subclasses overriding ``meta_data``/``entry`` get a matching key.
meta = {
    **self.meta_data,
    **kwargs,
    'num_restarts': self.__num_restarts,
}
key = self.entry(**meta)
if key in self.__stats.keys():
self.__stats[key] += value
elif initialize is not None:
Expand Down
4 changes: 4 additions & 0 deletions pySDC/core/Sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,7 @@ def level(self, L):
"""
assert isinstance(L, level)
self.__level = L

# Rank of this sweeper; this implementation always reports 0.
@property
def rank(self):
return 0
24 changes: 4 additions & 20 deletions pySDC/helpers/stats_helper.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,21 @@
import numpy as np


def filter_stats(
stats, process=None, time=None, level=None, iter=None, type=None, recomputed=None, num_restarts=None, comm=None
):
def filter_stats(stats, comm=None, recomputed=None, **kwargs):
"""
Helper function to extract data from the dictionary of statistics
Helper function to extract data from the dictionary of statistics. Please supply metadata as keyword arguments.
Args:
stats (dict): raw statistics from a controller run
process (int): process number
time (float): the requested simulation time
level (int): the requested level index
iter (int): the requested iteration count
type (str): string to describe the requested type of value
recomputed (bool): filter recomputed values from stats if set to anything other than None
comm (mpi4py.MPI.Intracomm): Communicator (or None if not applicable)
Returns:
dict: dictionary containing only the entries corresponding to the filter
"""
result = {}

for k, v in stats.items():
# get data if key matches the filter (if specified)
if (
(k.time == time or time is None)
and (k.process == process or process is None)
and (k.level == level or level is None)
and (k.iter == iter or iter is None)
and (k.type == type or type is None)
and (k.num_restarts == num_restarts or num_restarts is None)
):
# Keep the entry if it agrees with every filter that was actually given (a filter value of None means "do not
# filter on this field"). Fields missing from the key compare as None. all() of an empty iterable is True, so
# an unfiltered call keeps every entry without needing a sentinel.
if all(k._asdict().get(k2, None) == v2 for k2, v2 in kwargs.items() if v2 is not None):
    result[k] = v

if comm is not None:
Expand All @@ -55,7 +39,7 @@ def filter_stats(
]

# delete values that were recorded at times that shouldn't be recorded because we performed a different step after the restart
if type != '_recomputed':
if kwargs.get('type', None) != '_recomputed':
other_restarted_steps = [
key for key, val in filter_stats(stats, type='_recomputed', recomputed=False, comm=comm).items() if val
]
Expand Down
20 changes: 19 additions & 1 deletion pySDC/implementations/hooks/default_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def post_comm(self, step, level_number, add_to_stats=False):

self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time,
level=L.level_index,
iter=step.status.iter,
Expand Down Expand Up @@ -179,6 +180,7 @@ def post_sweep(self, step, level_number):

self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time,
level=L.level_index,
iter=step.status.iter,
Expand All @@ -188,6 +190,7 @@ def post_sweep(self, step, level_number):
)
self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time,
level=L.level_index,
iter=step.status.iter,
Expand All @@ -211,6 +214,7 @@ def post_iteration(self, step, level_number):

self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time,
level=-1,
iter=step.status.iter,
Expand All @@ -220,6 +224,7 @@ def post_iteration(self, step, level_number):
)
self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time,
level=L.level_index,
iter=step.status.iter,
Expand All @@ -243,6 +248,7 @@ def post_step(self, step, level_number):

self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time,
level=L.level_index,
iter=step.status.iter,
Expand All @@ -252,6 +258,7 @@ def post_step(self, step, level_number):
)
self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time,
level=-1,
iter=step.status.iter,
Expand All @@ -261,6 +268,7 @@ def post_step(self, step, level_number):
)
self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time,
level=L.level_index,
iter=-1,
Expand All @@ -272,7 +280,14 @@ def post_step(self, step, level_number):
# record the recomputed quantities at weird positions to make sure there is only one value for each step
for t in [L.time, L.time + L.dt]:
self.add_to_stats(
process=-1, time=t, level=-1, iter=-1, sweep=-1, type='_recomputed', value=step.status.get('restart')
process=-1,
time=t,
level=-1,
iter=-1,
sweep=-1,
type='_recomputed',
value=step.status.get('restart'),
process_sweeper=-1,
)

def post_predict(self, step, level_number):
Expand All @@ -290,6 +305,7 @@ def post_predict(self, step, level_number):

self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time,
level=L.level_index,
iter=step.status.iter,
Expand All @@ -313,6 +329,7 @@ def post_run(self, step, level_number):

self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time,
level=L.level_index,
iter=step.status.iter,
Expand All @@ -335,6 +352,7 @@ def post_setup(self, step, level_number):

self.add_to_stats(
process=-1,
process_sweeper=-1,
time=-1,
level=-1,
iter=-1,
Expand Down
1 change: 1 addition & 0 deletions pySDC/implementations/hooks/log_embedded_error_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def log_error(self, step, level_number, appendix=''):
value = L.status.error_embedded_estimate
self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time + L.dt,
level=L.level_index,
iter=iter,
Expand Down
5 changes: 5 additions & 0 deletions pySDC/implementations/hooks/log_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def log_global_error(self, step, level_number, suffix=''):

self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time + L.dt,
level=L.level_index,
iter=step.status.iter,
Expand All @@ -41,6 +42,7 @@ def log_global_error(self, step, level_number, suffix=''):
)
self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time + L.dt,
level=L.level_index,
iter=step.status.iter,
Expand Down Expand Up @@ -69,6 +71,7 @@ def log_local_error(self, step, level_number, suffix=''):

self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time + L.dt,
level=L.level_index,
iter=step.status.iter,
Expand Down Expand Up @@ -176,6 +179,7 @@ def post_run(self, step, level_number):

self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=self.t_last_solution,
level=L.level_index,
iter=step.status.iter,
Expand All @@ -185,6 +189,7 @@ def post_run(self, step, level_number):
)
self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=self.t_last_solution,
level=L.level_index,
iter=step.status.iter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def post_step(self, step, level_number):

self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time + L.dt,
level=L.level_index,
iter=step.status.iter,
Expand Down
1 change: 1 addition & 0 deletions pySDC/implementations/hooks/log_restarts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def post_step(self, step, level_number):

self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time,
level=L.level_index,
iter=step.status.iter,
Expand Down
1 change: 1 addition & 0 deletions pySDC/implementations/hooks/log_work.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def post_step(self, step, level_number):
for key in self.__work_last_step[step.status.slot][level_number].keys():
self.add_to_stats(
process=step.status.slot,
process_sweeper=L.sweep.rank,
time=L.time + L.dt,
level=L.level_index,
iter=step.status.iter,
Expand Down
12 changes: 9 additions & 3 deletions pySDC/implementations/sweeper_classes/generic_implicit_MPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@ def __init__(self, params):
self.logger.debug('Using MPI.COMM_WORLD for the communicator because none was supplied in the params.')
super().__init__(params)

self.rank = self.params.comm.Get_rank()

if self.params.comm.size != self.coll.num_nodes:
raise NotImplementedError(
f'The communicator in the {type(self).__name__} sweeper needs to have one rank for each node as of now! That means we need {self.coll.num_nodes} nodes, but got {self.params.comm.size} nodes.'
f'The communicator in the {type(self).__name__} sweeper needs to have one rank for each node as of now! That means we need {self.coll.num_nodes} nodes, but got {self.params.comm.size} processes.'
)

# Communicator used by this sweeper, read from ``self.params.comm``.
@property
def comm(self):
return self.params.comm

# Rank within the sweeper's communicator, read from ``self.comm.rank``.
@property
def rank(self):
return self.comm.rank

def compute_end_point(self):
"""
Compute u at the right point of the interval
Expand Down
9 changes: 0 additions & 9 deletions pySDC/projects/Resilience/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,6 @@ def post_step(self, step, level_number):

L = step.levels[level_number]

self.add_to_stats(
process=step.status.slot,
time=L.time,
level=L.level_index,
iter=step.status.iter,
sweep=L.status.sweep,
type='restart',
value=int(step.status.get('restart')),
)
# add the following with two names because I use both in different projects -.-
self.increment_stats(
process=step.status.slot,
Expand Down
Loading

0 comments on commit 314223c

Please sign in to comment.