diff --git a/src/perform_step/rkn_perform_step.jl b/src/perform_step/rkn_perform_step.jl index 2a247ed46c..6995e655b9 100644 --- a/src/perform_step/rkn_perform_step.jl +++ b/src/perform_step/rkn_perform_step.jl @@ -118,7 +118,7 @@ end f.f1(k.x[1], du, u, p, t + dt) f.f2(k.x[2], du, u, p, t + dt) - integrator.stats.nf += 1 + integrator.stats.nf += 4 integrator.stats.nf2 += 1 end @@ -159,7 +159,7 @@ end abar76 * k6) # abar72 = 0 k7 = f.f1(kdu, ku, p, t + dt * c7) - u = uprev + dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5)) # no b6, b7 + u = uprev + dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5)) # no b6, b7 du = duprev + dt * (bbar1 * k1 + bbar3 * k3 + bbar4 * k4 + bbar5 * k5 + bbar6 * k6) # no b2, b7 integrator.u = ArrayPartition((du, u)) @@ -223,7 +223,6 @@ end @.. broadcast=false ku=uprev + dt * (c7 * duprev + dt * (a71 * k1 + a73 * k3 + a74 * k4 + a75 * k5)) # a72 = a76 = 0 - @.. broadcast=false kdu=duprev + dt * (abar71 * k1 + abar73 * k3 + abar74 * k4 + abar75 * k5 + abar76 * k6) # abar72 = 0 @@ -242,12 +241,12 @@ end if integrator.opts.adaptive duhat, uhat = utilde.x dtsq = dt^2 - @.. broadcast=false uhat = dtsq * - (btilde1 * k1 + btilde3 * k3 + btilde4 * k4 + - btilde5 * k5) - @.. broadcast=false duhat = dt * - (bptilde1 * k1 + bptilde3 * k3 + bptilde4 * k4 + - bptilde5 * k5 + bptilde6 * k6 + bptilde7 * k7) + @.. broadcast=false uhat=dtsq * + (btilde1 * k1 + btilde3 * k3 + btilde4 * k4 + + btilde5 * k5) + @.. broadcast=false duhat=dt * + (bptilde1 * k1 + bptilde3 * k3 + bptilde4 * k4 + + bptilde5 * k5 + bptilde6 * k6 + bptilde7 * k7) calculate_residuals!(atmp, utilde, integrator.uprev, integrator.u, integrator.opts.abstol, integrator.opts.reltol, diff --git a/test/algconvergence/partitioned_methods_tests.jl b/test/algconvergence/partitioned_methods_tests.jl index b8943546ad..3e1a88f9f1 100644 --- a/test/algconvergence/partitioned_methods_tests.jl +++ b/test/algconvergence/partitioned_methods_tests.jl @@ -367,3 +367,175 @@ sol = solve(prob, ERKN5(), reltol = 1e-8) @test length(sol.u) < 34 sol = solve(prob, ERKN7(), reltol = 1e-8) @test length(sol.u) < 38 + +# Compare in-place and out-of-place versions +function damped_oscillator(du, u, p, t) + return -u - 0.5 * du +end +function damped_oscillator!(ddu, du, u, p, t) + @. ddu = -u - 0.5 * du + return nothing +end +@testset "in-place vs. out-of-place" begin + ode_i = SecondOrderODEProblem(damped_oscillator!, + [0.0], [1.0], + (0.0, 10.0)) + ode_o = SecondOrderODEProblem(damped_oscillator, + [0.0], [1.0], + (0.0, 10.0)) + + @testset "Nystrom4" begin + alg = Nystrom4() + dt = 0.5 + # fixed time step + sol_i = solve(ode_i, alg, dt = dt) + sol_o = solve(ode_o, alg, dt = dt) + @test sol_i.t ≈ sol_o.t + @test sol_i.u ≈ sol_o.u + @test sol_i.destats.nf == sol_o.destats.nf + @test sol_i.destats.nf2 == sol_o.destats.nf2 + @test sol_i.destats.naccept == sol_o.destats.naccept + @test 19 <= sol_i.destats.naccept <= 21 + @test abs(sol_i.destats.nf - 4 * sol_i.destats.naccept) < 4 + end + + @testset "FineRKN5" begin + alg = FineRKN5() + dt = 0.5 + # fixed time step + sol_i = solve(ode_i, alg, adaptive = false, dt = dt) + sol_o = solve(ode_o, alg, adaptive = false, dt = dt) + @test sol_i.t ≈ sol_o.t + @test sol_i.u ≈ sol_o.u + @test sol_i.destats.nf == sol_o.destats.nf + @test sol_i.destats.nf2 == sol_o.destats.nf2 + @test sol_i.destats.naccept == sol_o.destats.naccept + @test 19 <= sol_i.destats.naccept <= 21 + @test abs(sol_i.destats.nf - 7 * sol_i.destats.naccept) < 4 + # adaptive time step + sol_i = solve(ode_i, alg) + sol_o = solve(ode_o, alg) + @test_broken sol_i.t ≈ sol_o.t + @test_broken sol_i.u ≈ sol_o.u + end + + @testset "DPRKN4" begin + alg = DPRKN4() + dt = 0.5 + # fixed time step + sol_i = solve(ode_i, alg, adaptive = false, dt = dt) + sol_o = solve(ode_o, alg, adaptive = false, dt = dt) + @test sol_i.t ≈ sol_o.t + @test sol_i.u ≈ sol_o.u + @test sol_i.destats.nf == sol_o.destats.nf + @test sol_i.destats.nf2 == sol_o.destats.nf2 + @test sol_i.destats.naccept == sol_o.destats.naccept + @test 19 <= sol_i.destats.naccept <= 21 + @test abs(sol_i.destats.nf - 4 * sol_i.destats.naccept) < 4 + # adaptive time step + sol_i = solve(ode_i, alg) + sol_o = solve(ode_o, alg) + @test sol_i.t ≈ sol_o.t + @test sol_i.u ≈ sol_o.u + end + + @testset "DPRKN5" begin + alg = DPRKN5() + dt = 0.5 + # fixed time step + sol_i = solve(ode_i, alg, adaptive = false, dt = dt) + sol_o = solve(ode_o, alg, adaptive = false, dt = dt) + @test sol_i.t ≈ sol_o.t + @test sol_i.u ≈ sol_o.u + @test sol_i.destats.nf == sol_o.destats.nf + @test sol_i.destats.nf2 == sol_o.destats.nf2 + @test sol_i.destats.naccept == sol_o.destats.naccept + @test 19 <= sol_i.destats.naccept <= 21 + @test abs(sol_i.destats.nf - 6 * sol_i.destats.naccept) < 4 + # adaptive time step + sol_i = solve(ode_i, alg) + sol_o = solve(ode_o, alg) + @test sol_i.t ≈ sol_o.t + @test sol_i.u ≈ sol_o.u + end + + @testset "DPRKN6" begin + alg = DPRKN6() + dt = 0.5 + # fixed time step + sol_i = solve(ode_i, alg, adaptive = false, dt = dt) + sol_o = solve(ode_o, alg, adaptive = false, dt = dt) + @test sol_i.t ≈ sol_o.t + @test_broken sol_i.u ≈ sol_o.u + @test sol_i.destats.nf == sol_o.destats.nf + @test sol_i.destats.nf2 == sol_o.destats.nf2 + @test sol_i.destats.naccept == sol_o.destats.naccept + @test 19 <= sol_i.destats.naccept <= 21 + @test abs(sol_i.destats.nf - 6 * sol_i.destats.naccept) < 4 + # adaptive time step + sol_i = solve(ode_i, alg) + sol_o = solve(ode_o, alg) + @test_broken sol_i.t ≈ sol_o.t + @test_broken sol_i.u ≈ sol_o.u + end + + @testset "DPRKN6FM" begin + alg = DPRKN6FM() + dt = 0.5 + # fixed time step + sol_i = solve(ode_i, alg, adaptive = false, dt = dt) + sol_o = solve(ode_o, alg, adaptive = false, dt = dt) + @test sol_i.t ≈ sol_o.t + @test sol_i.u ≈ sol_o.u + @test sol_i.destats.nf == sol_o.destats.nf + @test sol_i.destats.nf2 == sol_o.destats.nf2 + @test sol_i.destats.naccept == sol_o.destats.naccept + @test 19 <= sol_i.destats.naccept <= 21 + @test abs(sol_i.destats.nf - 6 * sol_i.destats.naccept) < 4 + # adaptive time step + sol_i = solve(ode_i, alg) + sol_o = solve(ode_o, alg) + @test_broken sol_i.t ≈ sol_o.t + @test_broken sol_i.u ≈ sol_o.u + end + + @testset "DPRKN8" begin + alg = DPRKN8() + dt = 0.5 + # fixed time step + sol_i = solve(ode_i, alg, adaptive = false, dt = dt) + sol_o = solve(ode_o, alg, adaptive = false, dt = dt) + @test sol_i.t ≈ sol_o.t + @test sol_i.u ≈ sol_o.u + @test sol_i.destats.nf == sol_o.destats.nf + @test sol_i.destats.nf2 == sol_o.destats.nf2 + @test sol_i.destats.naccept == sol_o.destats.naccept + @test 19 <= sol_i.destats.naccept <= 21 + @test abs(sol_i.destats.nf - 9 * sol_i.destats.naccept) < 4 + # adaptive time step + sol_i = solve(ode_i, alg) + sol_o = solve(ode_o, alg) + @test_broken sol_i.t ≈ sol_o.t + @test_broken sol_i.u ≈ sol_o.u + end + + @testset "DPRKN12" begin + alg = DPRKN12() + dt = 0.5 + # fixed time step + sol_i = solve(ode_i, alg, adaptive = false, dt = dt) + sol_o = solve(ode_o, alg, adaptive = false, dt = dt) + @test sol_i.t ≈ sol_o.t + @test sol_i.u ≈ sol_o.u + @test sol_i.destats.nf == sol_o.destats.nf + @test sol_i.destats.nf2 == sol_o.destats.nf2 + @test sol_i.destats.naccept == sol_o.destats.naccept + @test 19 <= sol_i.destats.naccept <= 21 + @test abs(sol_i.destats.nf - 17 * sol_i.destats.naccept) < 4 + # adaptive time step + sol_i = solve(ode_i, alg) + sol_o = solve(ode_o, alg) + @test_broken sol_i.t ≈ sol_o.t + @test_broken sol_i.u ≈ sol_o.u + end +end