Skip to content

Commit

Permalink
Merge pull request #1973 from ranocha/hr/FineRKN5
Browse files Browse the repository at this point in the history
fix out-of-place version of FineRKN5
  • Loading branch information
ChrisRackauckas authored Jun 20, 2023
2 parents f569818 + 36c5fa6 commit 3d84dcb
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/perform_step/rkn_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ end

f.f1(k.x[1], du, u, p, t + dt)
f.f2(k.x[2], du, u, p, t + dt)
integrator.stats.nf += 1
integrator.stats.nf += 4
integrator.stats.nf2 += 1
end

Expand Down Expand Up @@ -159,7 +159,7 @@ end
abar76 * k6) # abar72 = 0

k7 = f.f1(kdu, ku, p, t + dt * c7)
u = uprev + dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5)) # no b6, b7
u = uprev + dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5)) # no b6, b7
du = duprev + dt * (bbar1 * k1 + bbar3 * k3 + bbar4 * k4 + bbar5 * k5 + bbar6 * k6) # no b2, b7

integrator.u = ArrayPartition((du, u))
Expand Down Expand Up @@ -223,7 +223,6 @@ end
@.. broadcast=false ku=uprev +
dt * (c7 * duprev +
dt * (a71 * k1 + a73 * k3 + a74 * k4 + a75 * k5)) # a72 = a76 = 0

@.. broadcast=false kdu=duprev +
dt * (abar71 * k1 + abar73 * k3 + abar74 * k4 +
abar75 * k5 + abar76 * k6) # abar72 = 0
Expand All @@ -242,12 +241,12 @@ end
if integrator.opts.adaptive
duhat, uhat = utilde.x
dtsq = dt^2
@.. broadcast=false uhat = dtsq *
(btilde1 * k1 + btilde3 * k3 + btilde4 * k4 +
btilde5 * k5)
@.. broadcast=false duhat = dt *
(bptilde1 * k1 + bptilde3 * k3 + bptilde4 * k4 +
bptilde5 * k5 + bptilde6 * k6 + bptilde7 * k7)
@.. broadcast=false uhat=dtsq *
(btilde1 * k1 + btilde3 * k3 + btilde4 * k4 +
btilde5 * k5)
@.. broadcast=false duhat=dt *
(bptilde1 * k1 + bptilde3 * k3 + bptilde4 * k4 +
bptilde5 * k5 + bptilde6 * k6 + bptilde7 * k7)

calculate_residuals!(atmp, utilde, integrator.uprev, integrator.u,
integrator.opts.abstol, integrator.opts.reltol,
Expand Down
172 changes: 172 additions & 0 deletions test/algconvergence/partitioned_methods_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,175 @@ sol = solve(prob, ERKN5(), reltol = 1e-8)
@test length(sol.u) < 34
sol = solve(prob, ERKN7(), reltol = 1e-8)
@test length(sol.u) < 38

# Compare in-place and out-of-place versions
function damped_oscillator(du, u, p, t)
return -u - 0.5 * du
end
function damped_oscillator!(ddu, du, u, p, t)
@. ddu = -u - 0.5 * du
return nothing
end
@testset "in-place vs. out-of-place" begin
ode_i = SecondOrderODEProblem(damped_oscillator!,
[0.0], [1.0],
(0.0, 10.0))
ode_o = SecondOrderODEProblem(damped_oscillator,
[0.0], [1.0],
(0.0, 10.0))

@testset "Nystrom4" begin
alg = Nystrom4()
dt = 0.5
# fixed time step
sol_i = solve(ode_i, alg, dt = dt)
sol_o = solve(ode_o, alg, dt = dt)
@test sol_i.t sol_o.t
@test sol_i.u sol_o.u
@test sol_i.destats.nf == sol_o.destats.nf
@test sol_i.destats.nf2 == sol_o.destats.nf2
@test sol_i.destats.naccept == sol_o.destats.naccept
@test 19 <= sol_i.destats.naccept <= 21
@test abs(sol_i.destats.nf - 4 * sol_i.destats.naccept) < 4
end

@testset "FineRKN5" begin
alg = FineRKN5()
dt = 0.5
# fixed time step
sol_i = solve(ode_i, alg, adaptive = false, dt = dt)
sol_o = solve(ode_o, alg, adaptive = false, dt = dt)
@test sol_i.t sol_o.t
@test sol_i.u sol_o.u
@test sol_i.destats.nf == sol_o.destats.nf
@test sol_i.destats.nf2 == sol_o.destats.nf2
@test sol_i.destats.naccept == sol_o.destats.naccept
@test 19 <= sol_i.destats.naccept <= 21
@test abs(sol_i.destats.nf - 7 * sol_i.destats.naccept) < 4
# adaptive time step
sol_i = solve(ode_i, alg)
sol_o = solve(ode_o, alg)
@test_broken sol_i.t sol_o.t
@test_broken sol_i.u sol_o.u
end

@testset "DPRKN4" begin
alg = DPRKN4()
dt = 0.5
# fixed time step
sol_i = solve(ode_i, alg, adaptive = false, dt = dt)
sol_o = solve(ode_o, alg, adaptive = false, dt = dt)
@test sol_i.t sol_o.t
@test sol_i.u sol_o.u
@test sol_i.destats.nf == sol_o.destats.nf
@test sol_i.destats.nf2 == sol_o.destats.nf2
@test sol_i.destats.naccept == sol_o.destats.naccept
@test 19 <= sol_i.destats.naccept <= 21
@test abs(sol_i.destats.nf - 4 * sol_i.destats.naccept) < 4
# adaptive time step
sol_i = solve(ode_i, alg)
sol_o = solve(ode_o, alg)
@test sol_i.t sol_o.t
@test sol_i.u sol_o.u
end

@testset "DPRKN5" begin
alg = DPRKN5()
dt = 0.5
# fixed time step
sol_i = solve(ode_i, alg, adaptive = false, dt = dt)
sol_o = solve(ode_o, alg, adaptive = false, dt = dt)
@test sol_i.t sol_o.t
@test sol_i.u sol_o.u
@test sol_i.destats.nf == sol_o.destats.nf
@test sol_i.destats.nf2 == sol_o.destats.nf2
@test sol_i.destats.naccept == sol_o.destats.naccept
@test 19 <= sol_i.destats.naccept <= 21
@test abs(sol_i.destats.nf - 6 * sol_i.destats.naccept) < 4
# adaptive time step
sol_i = solve(ode_i, alg)
sol_o = solve(ode_o, alg)
@test sol_i.t sol_o.t
@test sol_i.u sol_o.u
end

@testset "DPRKN6" begin
alg = DPRKN6()
dt = 0.5
# fixed time step
sol_i = solve(ode_i, alg, adaptive = false, dt = dt)
sol_o = solve(ode_o, alg, adaptive = false, dt = dt)
@test sol_i.t sol_o.t
@test_broken sol_i.u sol_o.u
@test sol_i.destats.nf == sol_o.destats.nf
@test sol_i.destats.nf2 == sol_o.destats.nf2
@test sol_i.destats.naccept == sol_o.destats.naccept
@test 19 <= sol_i.destats.naccept <= 21
@test abs(sol_i.destats.nf - 6 * sol_i.destats.naccept) < 4
# adaptive time step
sol_i = solve(ode_i, alg)
sol_o = solve(ode_o, alg)
@test_broken sol_i.t sol_o.t
@test_broken sol_i.u sol_o.u
end

@testset "DPRKN6FM" begin
alg = DPRKN6FM()
dt = 0.5
# fixed time step
sol_i = solve(ode_i, alg, adaptive = false, dt = dt)
sol_o = solve(ode_o, alg, adaptive = false, dt = dt)
@test sol_i.t sol_o.t
@test sol_i.u sol_o.u
@test sol_i.destats.nf == sol_o.destats.nf
@test sol_i.destats.nf2 == sol_o.destats.nf2
@test sol_i.destats.naccept == sol_o.destats.naccept
@test 19 <= sol_i.destats.naccept <= 21
@test abs(sol_i.destats.nf - 6 * sol_i.destats.naccept) < 4
# adaptive time step
sol_i = solve(ode_i, alg)
sol_o = solve(ode_o, alg)
@test_broken sol_i.t sol_o.t
@test_broken sol_i.u sol_o.u
end

@testset "DPRKN8" begin
alg = DPRKN8()
dt = 0.5
# fixed time step
sol_i = solve(ode_i, alg, adaptive = false, dt = dt)
sol_o = solve(ode_o, alg, adaptive = false, dt = dt)
@test sol_i.t sol_o.t
@test sol_i.u sol_o.u
@test sol_i.destats.nf == sol_o.destats.nf
@test sol_i.destats.nf2 == sol_o.destats.nf2
@test sol_i.destats.naccept == sol_o.destats.naccept
@test 19 <= sol_i.destats.naccept <= 21
@test abs(sol_i.destats.nf - 9 * sol_i.destats.naccept) < 4
# adaptive time step
sol_i = solve(ode_i, alg)
sol_o = solve(ode_o, alg)
@test_broken sol_i.t sol_o.t
@test_broken sol_i.u sol_o.u
end

@testset "DPRKN12" begin
alg = DPRKN12()
dt = 0.5
# fixed time step
sol_i = solve(ode_i, alg, adaptive = false, dt = dt)
sol_o = solve(ode_o, alg, adaptive = false, dt = dt)
@test sol_i.t sol_o.t
@test sol_i.u sol_o.u
@test sol_i.destats.nf == sol_o.destats.nf
@test sol_i.destats.nf2 == sol_o.destats.nf2
@test sol_i.destats.naccept == sol_o.destats.naccept
@test 19 <= sol_i.destats.naccept <= 21
@test abs(sol_i.destats.nf - 17 * sol_i.destats.naccept) < 4
# adaptive time step
sol_i = solve(ode_i, alg)
sol_o = solve(ode_o, alg)
@test_broken sol_i.t sol_o.t
@test_broken sol_i.u sol_o.u
end
end

0 comments on commit 3d84dcb

Please sign in to comment.