diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 6c9a0c70bc4..36d46853443 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -6,11 +6,11 @@ from cached_property import cached_property from sympy import sympify -from devito import switchconfig +from devito import mpi_switch_log from devito.arch import compiler_registry, platform_registry from devito.data import default_allocator -from devito.exceptions import InvalidOperator -from devito.logger import debug, info, perf, warning, is_log_enabled_for, set_log_level +from devito.exceptions import InvalidOperator, ExecutionError +from devito.logger import debug, info, perf, warning, is_log_enabled_for from devito.ir.equations import LoweredEq, lower_exprs from devito.ir.clusters import ClusterGroup, clusterize from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols, MetaCall, @@ -874,15 +874,12 @@ def apply(self, **kwargs): # In case MPI is used restrict result logging to one rank only if configuration['mpi']: - # Only temporarily change configuration - with switchconfig(mpi=True): - set_log_level('DEBUG', comm=args.comm) + with mpi_switch_log(log_level='DEBUG', comm=args.comm): return self._emit_apply_profiling(args) - - return self._emit_apply_profiling(args) + else: + return self._emit_apply_profiling(args) # Performance profiling - def _emit_build_profiling(self): if not is_log_enabled_for('PERF'): return @@ -916,7 +913,6 @@ def _emit_timings(timings, indent=''): def _emit_apply_profiling(self, args): """Produce a performance summary of the profiled sections.""" - # Rounder to 2 decimal places fround = lambda i: ceil(i * 100) / 100 diff --git a/devito/parameters.py b/devito/parameters.py index c97c38b89c5..3977a835eff 100644 --- a/devito/parameters.py +++ b/devito/parameters.py @@ -4,11 +4,11 @@ from os import environ from functools import wraps -from devito.logger import info, warning +from devito.logger import info, warning, logger, stream_handler, set_log_level from devito.tools import Signer, filter_ordered __all__ = ['configuration', 'init_configuration', 'print_defaults', 'print_state', - 'switchconfig'] + 'switchconfig', 'mpi_switch_log'] # Be EXTREMELY careful when writing to a Parameters dictionary # Read here for reference: http://wiki.c2.com/?GlobalVariablesAreBad @@ -258,6 +258,27 @@ def wrapper(*args, **kwargs): return wrapper +class mpi_switch_log(switchconfig): + """ + A context manager subclassing `switchconfig` to temporarily change + MPI logging. + """ + + def __init__(self, **params): + self.params = {k.replace('_', '-'): v for k, v in params.items()} + self.previous = {} + + comm = self.params.pop('comm') + + # Limit logging to rank 0 + set_log_level(self.params['log-level'], comm=comm) + + def __exit__(self, exc_type, exc_val, exc_tb): + # Reinstate logging upon exit + set_log_level(self.previous['log-level']) + logger.addHandler(stream_handler) + + def print_defaults(): """Print the environment variables accepted by Devito, their default value, as well as all of the accepted values."""