Skip to content

Commit

Permalink
extract_semidiscretization(integrator)
Browse files Browse the repository at this point in the history
  • Loading branch information
ranocha committed Nov 9, 2022
1 parent d7f05a3 commit 88dfa20
Show file tree
Hide file tree
Showing 15 changed files with 32 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Polyester = "0.3.4, 0.5, 0.6"
RecipesBase = "1.1"
Reexport = "1.0"
Requires = "1.1"
SciMLBase = "1.21"
SciMLBase = "1.63"
Setfield = "0.8, 1"
StartUpDG = "0.14"
Static = "0.3, 0.4, 0.5, 0.6, 0.7"
Expand Down
3 changes: 2 additions & 1 deletion src/Trixi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, sparse, dropt
# import @reexport now to make it available for further imports/exports
using Reexport: @reexport

using SciMLBase: CallbackSet, DiscreteCallback,
using SciMLBase: SciMLBase, unwrapped_f,
CallbackSet, DiscreteCallback,
ODEProblem, ODESolution, ODEFunction,
SplitODEProblem
import SciMLBase: get_du, get_tmp_cache, u_modified!,
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks_step/amr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ end

function initialize!(cb::DiscreteCallback{Condition,Affect!}, u, t, integrator) where {Condition, Affect!<:AMRCallback}
amr_callback = cb.affect!
semi = integrator.p
semi = extract_semidiscretization(integrator)

@trixi_timeit timer() "initial condition AMR" if amr_callback.adapt_initial_condition
# iterate until mesh does not change anymore
Expand Down Expand Up @@ -161,7 +161,7 @@ end

function (amr_callback::AMRCallback)(integrator; kwargs...)
u_ode = integrator.u
semi = integrator.p
semi = extract_semidiscretization(integrator)

@trixi_timeit timer() "AMR" begin
has_changed = amr_callback(u_ode, semi,
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks_step/analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ end


function initialize!(cb::DiscreteCallback{Condition,Affect!}, u_ode, t, integrator) where {Condition, Affect!<:AnalysisCallback}
semi = integrator.p
semi = extract_semidiscretization(integrator)
initial_state_integrals = integrate(u_ode, semi)
_, equations, _, _ = mesh_equations_solver_cache(semi)

Expand Down Expand Up @@ -178,7 +178,7 @@ end

# TODO: Taal refactor, allow passing an IO object (which could be devnull to avoid cluttering the console)
function (analysis_callback::AnalysisCallback)(integrator)
semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
@unpack dt, t = integrator
iter = integrator.destats.naccept
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks_step/averaging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ end

function initialize!(cb::DiscreteCallback{Condition,Affect!}, u_ode, t, integrator) where {Condition, Affect!<:AveragingCallback}
averaging_callback = cb.affect!
semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
u = wrap_array(u_ode, mesh, equations, solver, cache)

Expand All @@ -86,7 +86,7 @@ function (averaging_callback::AveragingCallback)(integrator)

u_ode = integrator.u
u_prev_ode = integrator.uprev
semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
u = wrap_array(u_ode, mesh, equations, solver, cache)
u_prev = wrap_array(u_prev_ode, mesh, equations, solver, cache)
Expand Down
2 changes: 1 addition & 1 deletion src/callbacks_step/glm_speed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ end
@inline function (glm_speed_callback::GlmSpeedCallback)(integrator)

dt = get_proposed_dt(integrator)
semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
@unpack glm_scale, cfl = glm_speed_callback

Expand Down
2 changes: 1 addition & 1 deletion src/callbacks_step/lbm_collision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ end
@inline function lbm_collision_callback(integrator)

dt = get_proposed_dt(integrator)
semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
@unpack collision_op = equations

Expand Down
4 changes: 2 additions & 2 deletions src/callbacks_step/save_restart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function initialize!(cb::DiscreteCallback{Condition,Affect!}, u, t, integrator)

mpi_isroot() && mkpath(restart_callback.output_directory)

semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, _, _, _ = mesh_equations_solver_cache(semi)
@trixi_timeit timer() "I/O" begin
if mesh.unsaved_changes
Expand Down Expand Up @@ -95,7 +95,7 @@ function (restart_callback::SaveRestartCallback)(integrator)
u_ode = integrator.u
@unpack t, dt = integrator
iter = integrator.destats.naccept
semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, _, _, _ = mesh_equations_solver_cache(semi)

@trixi_timeit timer() "I/O" begin
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks_step/save_solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function initialize!(cb::DiscreteCallback{Condition,Affect!}, u, t, integrator)

mpi_isroot() && mkpath(solution_callback.output_directory)

semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, _, _, _ = mesh_equations_solver_cache(semi)
@trixi_timeit timer() "I/O" begin
if mesh.unsaved_changes
Expand Down Expand Up @@ -110,7 +110,7 @@ function (solution_callback::SaveSolutionCallback)(integrator)
u_ode = integrator.u
@unpack t, dt = integrator
iter = integrator.destats.naccept
semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, _, _, _ = mesh_equations_solver_cache(semi)

@trixi_timeit timer() "I/O" begin
Expand Down
2 changes: 1 addition & 1 deletion src/callbacks_step/steady_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ end

# the condition
function (steady_state_callback::SteadyStateCallback)(u_ode, t, integrator)
semi = integrator.p
semi = extract_semidiscretization(integrator)

u = wrap_array(u_ode, semi)
du = wrap_array(get_du(integrator), semi)
Expand Down
2 changes: 1 addition & 1 deletion src/callbacks_step/stepsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ end
if !integrator.opts.adaptive
t = integrator.t
u_ode = integrator.u
semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
@unpack cfl_number = stepsize_callback
u = wrap_array(u_ode, mesh, equations, solver, cache)
Expand Down
2 changes: 1 addition & 1 deletion src/callbacks_step/summary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ function initialize_summary_callback(cb::DiscreteCallback, u, t, integrator)
:total_width => 100,
:indentation_level => 0)

semi = integrator.p
semi = extract_semidiscretization(integrator)
show(io_context, MIME"text/plain"(), semi)
println(io, "\n")
mesh, equations, solver, _ = mesh_equations_solver_cache(semi)
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks_step/time_series.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ function (time_series_callback::TimeSeriesCallback)(integrator)

# Unpack data
u_ode = integrator.u
semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
u = wrap_array(u_ode, mesh, equations, solver, cache)

Expand All @@ -196,7 +196,7 @@ function (time_series_callback::TimeSeriesCallback)(integrator)

# Store time_series if this is the last time step
if isfinished(integrator)
semi = integrator.p
semi = extract_semidiscretization(integrator)
mesh, equations, solver, _ = mesh_equations_solver_cache(semi)
save_time_series_file(time_series_callback, mesh, equations, solver)
end
Expand Down
2 changes: 1 addition & 1 deletion src/callbacks_step/visualization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ end
# this method is called when the callback is activated
function (visualization_callback::VisualizationCallback)(integrator)
u_ode = integrator.u
semi = integrator.p
semi = extract_semidiscretization(integrator)
@unpack plot_arguments, solution_variables, variable_names, show_mesh, plot_data_creator, plot_creator = visualization_callback

# Extract plot data
Expand Down
11 changes: 11 additions & 0 deletions src/semidiscretization/semidiscretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,17 @@ function semidiscretize(semi::AbstractSemidiscretization, tspan)
end


# get the semidiscretization from an `ODEIntegrator`
function extract_semidiscretization(integrator)
f = unwrapped_f(integrator.f.f)
if f isa RHSWrapper
return f.semi
else
return integrator.p
end
end


"""
semidiscretize(semi::AbstractSemidiscretization, tspan, restart_file::AbstractString)
Expand Down

0 comments on commit 88dfa20

Please sign in to comment.