diff --git a/Project.toml b/Project.toml index d54d7b338bd..7475ae9bce2 100644 --- a/Project.toml +++ b/Project.toml @@ -60,7 +60,7 @@ Polyester = "0.3.4, 0.5, 0.6" RecipesBase = "1.1" Reexport = "1.0" Requires = "1.1" -SciMLBase = "1.21" +SciMLBase = "1.63" Setfield = "0.8, 1" StartUpDG = "0.14" Static = "0.3, 0.4, 0.5, 0.6, 0.7" diff --git a/src/Trixi.jl b/src/Trixi.jl index 5811a950505..2ec92d67532 100644 --- a/src/Trixi.jl +++ b/src/Trixi.jl @@ -25,7 +25,8 @@ using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, sparse, dropt # import @reexport now to make it available for further imports/exports using Reexport: @reexport -using SciMLBase: CallbackSet, DiscreteCallback, +using SciMLBase: SciMLBase, unwrapped_f, + CallbackSet, DiscreteCallback, ODEProblem, ODESolution, ODEFunction, SplitODEProblem import SciMLBase: get_du, get_tmp_cache, u_modified!, diff --git a/src/callbacks_step/amr.jl b/src/callbacks_step/amr.jl index 6a1a73c838f..0bd14d9eeef 100644 --- a/src/callbacks_step/amr.jl +++ b/src/callbacks_step/amr.jl @@ -122,7 +122,7 @@ end function initialize!(cb::DiscreteCallback{Condition,Affect!}, u, t, integrator) where {Condition, Affect!<:AMRCallback} amr_callback = cb.affect! - semi = integrator.p + semi = extract_semidiscretization(integrator) @trixi_timeit timer() "initial condition AMR" if amr_callback.adapt_initial_condition # iterate until mesh does not change anymore @@ -161,7 +161,7 @@ end function (amr_callback::AMRCallback)(integrator; kwargs...) u_ode = integrator.u - semi = integrator.p + semi = extract_semidiscretization(integrator) @trixi_timeit timer() "AMR" begin has_changed = amr_callback(u_ode, semi, diff --git a/src/callbacks_step/analysis.jl b/src/callbacks_step/analysis.jl index 59b2d2e9113..681a8d38306 100644 --- a/src/callbacks_step/analysis.jl +++ b/src/callbacks_step/analysis.jl @@ -114,7 +114,7 @@ end function initialize!(cb::DiscreteCallback{Condition,Affect!}, u_ode, t, integrator) where {Condition, Affect!<:AnalysisCallback} - semi = integrator.p + semi = extract_semidiscretization(integrator) initial_state_integrals = integrate(u_ode, semi) _, equations, _, _ = mesh_equations_solver_cache(semi) @@ -178,7 +178,7 @@ end # TODO: Taal refactor, allow passing an IO object (which could be devnull to avoid cluttering the console) function (analysis_callback::AnalysisCallback)(integrator) - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, equations, solver, cache = mesh_equations_solver_cache(semi) @unpack dt, t = integrator iter = integrator.destats.naccept diff --git a/src/callbacks_step/averaging.jl b/src/callbacks_step/averaging.jl index 1052efe4bee..51e7dd15e6f 100644 --- a/src/callbacks_step/averaging.jl +++ b/src/callbacks_step/averaging.jl @@ -67,7 +67,7 @@ end function initialize!(cb::DiscreteCallback{Condition,Affect!}, u_ode, t, integrator) where {Condition, Affect!<:AveragingCallback} averaging_callback = cb.affect! - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, equations, solver, cache = mesh_equations_solver_cache(semi) u = wrap_array(u_ode, mesh, equations, solver, cache) @@ -86,7 +86,7 @@ function (averaging_callback::AveragingCallback)(integrator) u_ode = integrator.u u_prev_ode = integrator.uprev - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, equations, solver, cache = mesh_equations_solver_cache(semi) u = wrap_array(u_ode, mesh, equations, solver, cache) u_prev = wrap_array(u_prev_ode, mesh, equations, solver, cache) diff --git a/src/callbacks_step/glm_speed.jl b/src/callbacks_step/glm_speed.jl index 03809c97e83..5a59a17cc4f 100644 --- a/src/callbacks_step/glm_speed.jl +++ b/src/callbacks_step/glm_speed.jl @@ -74,7 +74,7 @@ end @inline function (glm_speed_callback::GlmSpeedCallback)(integrator) dt = get_proposed_dt(integrator) - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, equations, solver, cache = mesh_equations_solver_cache(semi) @unpack glm_scale, cfl = glm_speed_callback diff --git a/src/callbacks_step/lbm_collision.jl b/src/callbacks_step/lbm_collision.jl index 7bd11830c63..1d46f18383e 100644 --- a/src/callbacks_step/lbm_collision.jl +++ b/src/callbacks_step/lbm_collision.jl @@ -49,7 +49,7 @@ end @inline function lbm_collision_callback(integrator) dt = get_proposed_dt(integrator) - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, equations, solver, cache = mesh_equations_solver_cache(semi) @unpack collision_op = equations diff --git a/src/callbacks_step/save_restart.jl b/src/callbacks_step/save_restart.jl index 33ce0910ba9..c5c7d9ff026 100644 --- a/src/callbacks_step/save_restart.jl +++ b/src/callbacks_step/save_restart.jl @@ -62,7 +62,7 @@ function initialize!(cb::DiscreteCallback{Condition,Affect!}, u, t, integrator) mpi_isroot() && mkpath(restart_callback.output_directory) - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, _, _, _ = mesh_equations_solver_cache(semi) @trixi_timeit timer() "I/O" begin if mesh.unsaved_changes @@ -95,7 +95,7 @@ function (restart_callback::SaveRestartCallback)(integrator) u_ode = integrator.u @unpack t, dt = integrator iter = integrator.destats.naccept - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, _, _, _ = mesh_equations_solver_cache(semi) @trixi_timeit timer() "I/O" begin diff --git a/src/callbacks_step/save_solution.jl b/src/callbacks_step/save_solution.jl index 1efa2146ca3..5dc835bebbd 100644 --- a/src/callbacks_step/save_solution.jl +++ b/src/callbacks_step/save_solution.jl @@ -73,7 +73,7 @@ function initialize!(cb::DiscreteCallback{Condition,Affect!}, u, t, integrator) mpi_isroot() && mkpath(solution_callback.output_directory) - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, _, _, _ = mesh_equations_solver_cache(semi) @trixi_timeit timer() "I/O" begin if mesh.unsaved_changes @@ -110,7 +110,7 @@ function (solution_callback::SaveSolutionCallback)(integrator) u_ode = integrator.u @unpack t, dt = integrator iter = integrator.destats.naccept - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, _, _, _ = mesh_equations_solver_cache(semi) @trixi_timeit timer() "I/O" begin diff --git a/src/callbacks_step/steady_state.jl b/src/callbacks_step/steady_state.jl index 66d04fea704..743655ba60d 100644 --- a/src/callbacks_step/steady_state.jl +++ b/src/callbacks_step/steady_state.jl @@ -56,7 +56,7 @@ end # the condition function (steady_state_callback::SteadyStateCallback)(u_ode, t, integrator) - semi = integrator.p + semi = extract_semidiscretization(integrator) u = wrap_array(u_ode, semi) du = wrap_array(get_du(integrator), semi) diff --git a/src/callbacks_step/stepsize.jl b/src/callbacks_step/stepsize.jl index 13e4f9dfa54..56201ee76b6 100644 --- a/src/callbacks_step/stepsize.jl +++ b/src/callbacks_step/stepsize.jl @@ -67,7 +67,7 @@ end if !integrator.opts.adaptive t = integrator.t u_ode = integrator.u - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, equations, solver, cache = mesh_equations_solver_cache(semi) @unpack cfl_number = stepsize_callback u = wrap_array(u_ode, mesh, equations, solver, cache) diff --git a/src/callbacks_step/summary.jl b/src/callbacks_step/summary.jl index 147af72af44..fae9d459f2c 100644 --- a/src/callbacks_step/summary.jl +++ b/src/callbacks_step/summary.jl @@ -148,7 +148,7 @@ function initialize_summary_callback(cb::DiscreteCallback, u, t, integrator) :total_width => 100, :indentation_level => 0) - semi = integrator.p + semi = extract_semidiscretization(integrator) show(io_context, MIME"text/plain"(), semi) println(io, "\n") mesh, equations, solver, _ = mesh_equations_solver_cache(semi) diff --git a/src/callbacks_step/time_series.jl b/src/callbacks_step/time_series.jl index 9bce17eadd0..cf649ddcf8d 100644 --- a/src/callbacks_step/time_series.jl +++ b/src/callbacks_step/time_series.jl @@ -181,7 +181,7 @@ function (time_series_callback::TimeSeriesCallback)(integrator) # Unpack data u_ode = integrator.u - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, equations, solver, cache = mesh_equations_solver_cache(semi) u = wrap_array(u_ode, mesh, equations, solver, cache) @@ -196,7 +196,7 @@ function (time_series_callback::TimeSeriesCallback)(integrator) # Store time_series if this is the last time step if isfinished(integrator) - semi = integrator.p + semi = extract_semidiscretization(integrator) mesh, equations, solver, _ = mesh_equations_solver_cache(semi) save_time_series_file(time_series_callback, mesh, equations, solver) end diff --git a/src/callbacks_step/visualization.jl b/src/callbacks_step/visualization.jl index 5db0b932751..c94a5cb8ec8 100644 --- a/src/callbacks_step/visualization.jl +++ b/src/callbacks_step/visualization.jl @@ -139,7 +139,7 @@ end # this method is called when the callback is activated function (visualization_callback::VisualizationCallback)(integrator) u_ode = integrator.u - semi = integrator.p + semi = extract_semidiscretization(integrator) @unpack plot_arguments, solution_variables, variable_names, show_mesh, plot_data_creator, plot_creator = visualization_callback # Extract plot data diff --git a/src/semidiscretization/semidiscretization.jl b/src/semidiscretization/semidiscretization.jl index 8274b18427d..4e2ea5bb266 100644 --- a/src/semidiscretization/semidiscretization.jl +++ b/src/semidiscretization/semidiscretization.jl @@ -92,6 +92,17 @@ function semidiscretize(semi::AbstractSemidiscretization, tspan) end +# get the semidiscretization from an `ODEIntegrator` +function extract_semidiscretization(integrator) + f = unwrapped_f(integrator.f.f) + if f isa RHSWrapper + return f.semi + else + return integrator.p + end +end + + """ semidiscretize(semi::AbstractSemidiscretization, tspan, restart_file::AbstractString)