Skip to content

Commit

Permalink
Minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
brownbaerchen committed Jan 22, 2025
1 parent 5bc126a commit 9529f63
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 69 deletions.
80 changes: 12 additions & 68 deletions pySDC/implementations/problem_classes/GenericGusto.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,73 +133,6 @@ def __init__(
self.solver_parameters = solver_parameters
self.stop_at_divergence = stop_at_divergence

if len(active_labels) > 0:
self.residual = self.residual.label_map(
lambda t: any(t.has_label(time_derivative, *active_labels)), map_if_false=drop
)

self.evaluate_source = []
self.physics_names = []
for t in self.residual:
if t.has_label(physics_label):
physics_name = t.get(physics_label)
if t.labels[physics_name] not in self.physics_names:
self.evaluate_source.append(t.labels[physics_name])
self.physics_names.append(t.labels[physics_name])

# Check if there are any mass-weighted terms:
if len(self.residual.label_map(lambda t: t.has_label(mass_weighted), map_if_false=drop)) > 0:
for field in equation.field_names:

# Check if the mass term for this prognostic is mass-weighted
if (
len(
self.residual.label_map(
(
lambda t: t.get(prognostic) == field
and t.has_label(time_derivative)
and t.has_label(mass_weighted)
),
map_if_false=drop,
)
)
== 1
):

field_terms = self.residual.label_map(
lambda t: t.get(prognostic) == field and not t.has_label(time_derivative), map_if_false=drop
)

# Check that the equation for this prognostic does not involve
# both mass-weighted and non-mass-weighted terms; if so, a split
# timestepper should be used instead.
if len(field_terms.label_map(lambda t: t.has_label(mass_weighted), map_if_false=drop)) > 0:
if len(field_terms.label_map(lambda t: not t.has_label(mass_weighted), map_if_false=drop)) > 0:
raise ValueError(
'Mass-weighted and non-mass-weighted terms are present in a '
+ f'timestepping equation for {field}. As these terms cannot '
+ 'be solved for simultaneously, a split timestepping method '
+ 'should be used instead.'
)
else:
# Replace the terms with a mass_weighted label with the
# mass_weighted form. It is important that the labels from
# this new form are used.
self.residual = self.residual.label_map(
lambda t: t.get(prognostic) == field and t.has_label(mass_weighted),
map_if_true=lambda t: t.get(mass_weighted),
)
self.idx = None

# -------------------------------------------------------------------- #
# Make boundary conditions
# -------------------------------------------------------------------- #

if not apply_bcs:
self.bcs = None
else:
self.bcs = equation.bcs[equation.field_name]

# -------------------------------------------------------------------- #
# Setup caches
# -------------------------------------------------------------------- #
Expand All @@ -209,12 +142,23 @@ def __init__(
self._u = fd.Function(self.fs)

super().__init__(self.fs)
self._makeAttributeAndRegister('LHS_cache_size', localVars=locals(), readOnly=True)
self._makeAttributeAndRegister('LHS_cache_size', 'apply_bcs', localVars=locals(), readOnly=True)
self.work_counters['rhs'] = WorkCounter()
self.work_counters['ksp'] = WorkCounter()
self.work_counters['solver_setup'] = WorkCounter()
self.work_counters['solver'] = WorkCounter()

# @property
# def residual(self):
# return self.equation.residual

@property
def bcs(self):
if not self.apply_bcs:
return None
else:
return self.equation.bcs[self.equation.field_name]

def invert_mass_matrix(self, rhs):
self._u.assign(rhs.functionspace)

Expand Down
2 changes: 1 addition & 1 deletion pySDC/tests/test_problems/test_generic_gusto.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ def test_pySDC_integrator_RK(use_transport_scheme, method):
# Setup time steppers
# ------------------------------------------------------------------------ #

stepper_gusto = get_gusto_stepper(eqns, gusto_method(domain, solver_parameters=solver_parameters), spatial_methods)
stepper_pySDC = get_gusto_stepper(
eqns,
pySDC_integrator(
Expand All @@ -269,6 +268,7 @@ def test_pySDC_integrator_RK(use_transport_scheme, method):
),
spatial_methods,
)
stepper_gusto = get_gusto_stepper(eqns, gusto_method(domain, solver_parameters=solver_parameters), spatial_methods)

# ------------------------------------------------------------------------ #
# Run tests
Expand Down

0 comments on commit 9529f63

Please sign in to comment.