From a8e00e65fd406723f5d44b39bc3e08ae248f45ca Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Sun, 15 Dec 2024 18:32:36 +0100 Subject: [PATCH] remove grid requirements --- src/Models/Models.jl | 3 +++ src/Simulations/simulation.jl | 2 +- src/Simulations/time_step_wizard.jl | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Models/Models.jl b/src/Models/Models.jl index 4981986dd5..0a4fd5d44f 100644 --- a/src/Models/Models.jl +++ b/src/Models/Models.jl @@ -23,6 +23,8 @@ import Oceananigans.Architectures: architecture import Oceananigans.TimeSteppers: reset! import Oceananigans.Solvers: iteration +import Base + # A prototype interface for AbstractModel. # # TODO: decide if we like this. @@ -34,6 +36,7 @@ import Oceananigans.Solvers: iteration iteration(model::AbstractModel) = model.clock.iteration Base.time(model::AbstractModel) = model.clock.time +Base.eltype(model::AbstractModel) = eltype(model.grid) architecture(model::AbstractModel) = model.grid.architecture initialize!(model::AbstractModel) = nothing total_velocities(model::AbstractModel) = nothing diff --git a/src/Simulations/simulation.jl b/src/Simulations/simulation.jl index bd6a0fbbe1..835fa4de2d 100644 --- a/src/Simulations/simulation.jl +++ b/src/Simulations/simulation.jl @@ -73,7 +73,7 @@ function Simulation(model; Δt, # Convert numbers to floating point; otherwise preserve type (eg for DateTime types) # TODO: implement TT = timetype(model) and FT = eltype(model) - TT = eltype(model.grid) + TT = eltype(model) Δt = Δt isa Number ? TT(Δt) : Δt stop_time = stop_time isa Number ? TT(stop_time) : stop_time diff --git a/src/Simulations/time_step_wizard.jl b/src/Simulations/time_step_wizard.jl index d8773d5e7f..cc15f3ea25 100644 --- a/src/Simulations/time_step_wizard.jl +++ b/src/Simulations/time_step_wizard.jl @@ -108,7 +108,7 @@ function new_time_step(old_Δt, wizard, model) new_Δt = min(wizard.max_change * old_Δt, new_Δt) new_Δt = max(wizard.min_change * old_Δt, new_Δt) new_Δt = clamp(new_Δt, wizard.min_Δt, wizard.max_Δt) - new_Δt = all_reduce(min, new_Δt, architecture(model.grid)) + new_Δt = all_reduce(min, new_Δt, architecture(model)) return new_Δt end