From bc115557dc45d62392e4e0ac7e55963bfef14e22 Mon Sep 17 00:00:00 2001 From: Joshua Lampert <51029046+JoshuaLampert@users.noreply.github.com> Date: Wed, 13 Dec 2023 18:54:07 +0100 Subject: [PATCH] introduce type parameters in solver (#77) --- src/solver.jl | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/solver.jl b/src/solver.jl index 1bc2d645..6070ac19 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -10,13 +10,19 @@ abstract type AbstractSolver end A `struct` that holds the summation by parts (SBP) operators that are used for the spatial discretization. """ -struct Solver{RealT <: Real} <: AbstractSolver - D1::AbstractDerivativeOperator{RealT} - D2::Union{AbstractDerivativeOperator{RealT}, AbstractMatrix{RealT}} - - function Solver{RealT}(D1::AbstractDerivativeOperator{RealT}, - D2::Union{AbstractDerivativeOperator{RealT}, - AbstractMatrix{RealT}}) where {RealT} +struct Solver{RealT <: Real, FirstDerivative <: AbstractDerivativeOperator{RealT}, + SecondDerivative <: + Union{AbstractDerivativeOperator{RealT}, AbstractMatrix{RealT}}} <: + AbstractSolver + D1::FirstDerivative + D2::SecondDerivative + + function Solver{RealT, FirstDerivative, SecondDerivative}(D1::FirstDerivative, + D2::SecondDerivative) where { + RealT, + FirstDerivative, + SecondDerivative + } @assert derivative_order(D1) == 1 if D2 isa AbstractDerivativeOperator @assert derivative_order(D2) == 2 @@ -35,7 +41,7 @@ function Solver(mesh, accuracy_order) D1 = periodic_derivative_operator(1, accuracy_order, mesh.xmin, mesh.xmax, mesh.N) D2 = periodic_derivative_operator(2, accuracy_order, mesh.xmin, mesh.xmax, mesh.N) @assert real(D1) == real(D2) - Solver{real(D1)}(D1, D2) + Solver{real(D1), typeof(D1), typeof(D2)}(D1, D2) end # Also allow to pass custom SBP operators (for convenience without explicitly specifying the type) @@ -50,7 +56,7 @@ function Solver(D1::AbstractDerivativeOperator{RealT}, D2::Union{AbstractDerivativeOperator{RealT}, AbstractMatrix{RealT}}) where { RealT } - Solver{RealT}(D1, D2) + Solver{RealT, typeof(D1), typeof(D2)}(D1, D2) end function Base.show(io::IO, solver::Solver{RealT}) where {RealT}