
wrap RHS to reduce latency with OrdinaryDiffEq.jl #1255

Draft: wants to merge 16 commits into base: main

ranocha (Member) commented Nov 9, 2022

Closes #1241

TODO

  • fix the wrapping and unwrapping
    • basic stuff
    • EulerAcousticsCouplingCallback
    • positivity-preserving limiters
    • plotting
  • decide how to handle type instability when accessing the semidiscretization in callbacks
    • function barriers for now, but this may need to be changed depending on extensive benchmarks
  • describe how to precompile low-storage methods for OrdinaryDiffEq.jl in our docs, or at least link to their docs (set_preferences!(UUID("1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"), "PrecompileLowStorage" => true))
  • run benchmarks
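For the precompilation item in the list above, here is a minimal sketch of setting the quoted preference from a Julia session. It assumes Preferences.jl is installed and OrdinaryDiffEq.jl is a dependency of the active project. The preference is written to the active project's LocalPreferences.toml and presumably only takes effect the next time OrdinaryDiffEq.jl is precompiled (e.g., in a fresh Julia session).

```julia
# Assumes Preferences.jl is installed and OrdinaryDiffEq.jl is a dependency of
# the active project. The UUID is the one quoted in the TODO item above.
using Preferences, UUIDs

const ORDINARYDIFFEQ_UUID = UUID("1dea7af3-3e70-54e6-95c3-0bf5283fa5ed")

# Request precompilation of the low-storage methods; the value is stored in
# LocalPreferences.toml next to the active Project.toml.
set_preferences!(ORDINARYDIFFEQ_UUID, "PrecompileLowStorage" => true)

# Read the stored value back (returns `nothing` if it has not been set).
load_preference(ORDINARYDIFFEQ_UUID, "PrecompileLowStorage")
```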

Comment on lines 97 to 102
if integrator.f isa RHSWrapper
    return integrator.f.semi
else
    return integrator.p
end
Member

OOC (out-of-curiosity): Is this type stable? Or would it make sense to dispatch on the type of integrator.f, which is probably a type parameter of integrator? Or am I completely off track since we always return a semi and thus the return type is stable?

ranocha (Member Author), Nov 9, 2022

It should be type stable since it should specialize on the type of integrator, which knows the type of integrator.f - but it may be worth checking to be sure, and running some benchmarks, of course.
Edit: See below

Member

OK, thanks!
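To make the reviewed snippet and the type-stability argument concrete, here is a self-contained toy sketch. When the concrete wrapper type is part of the integrator's type, the `isa` check can be resolved during type inference, so `get_semi` infers a concrete return type. `ToySemi`, `toy_rhs!`, this definition of `RHSWrapper`, `ToyIntegrator`, and `get_semi` are hypothetical stand-ins, not the code from this PR or from OrdinaryDiffEq.jl.

```julia
# Hypothetical stand-in for a Trixi.jl semidiscretization.
struct ToySemi
    speed::Float64
end

# Hypothetical right-hand side of the semidiscretized ODE: du = -speed * u.
function toy_rhs!(du, u, semi::ToySemi, t)
    du .= -semi.speed .* u
    return nothing
end

# Store the semidiscretization inside the RHS function object instead of
# passing it as the ODE parameter `p`.
struct RHSWrapper{Semi}
    semi::Semi
end

# Callable with the in-place ODE signature f(du, u, p, t); `p` is unused.
(f::RHSWrapper)(du, u, p, t) = toy_rhs!(du, u, f.semi, t)

# Minimal integrator-like container with concretely typed `f` and `p`.
struct ToyIntegrator{F, P}
    f::F
    p::P
end

# Mirrors the reviewed snippet: take `semi` from the wrapper if present,
# otherwise fall back to the ODE parameter `p`.
function get_semi(integrator)
    if integrator.f isa RHSWrapper
        return integrator.f.semi
    else
        return integrator.p
    end
end

semi = ToySemi(2.0)
integrator = ToyIntegrator(RHSWrapper(semi), nothing)

u = [1.0, 2.0]
du = similar(u)
integrator.f(du, u, integrator.p, 0.0)   # du is now [-2.0, -4.0]

# `F` is concrete here, so the `isa` branch is resolved during inference.
Base.return_types(get_semi, (typeof(integrator),))   # [ToySemi]
```

This only holds as long as the wrapper's concrete type survives in the integrator's type; if the solver hides it, the same access is no longer inferable.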

Member Author

The wrapping that loses type information happens at the level of OrdinaryDiffEq.jl. This means that accessing the semidiscretization like this is indeed not type stable. Thus, I can think of the following options:

  • keep the code as it is in this PR, introducing type instabilities in callbacks accessing semi
  • introduce a function barrier in all callbacks after accessing semi to mitigate this problem
  • store the semidiscretization semi in the callbacks requiring it (if feasible, i.e., no full duplication, just a pointer)
  • keep the status quo without this PR
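A minimal sketch of the function-barrier option from the list above. The type-erased wrapping is modelled by a hypothetical `ErasedIntegrator` whose RHS field is typed `Any`; the actual wrapping in OrdinaryDiffEq.jl is different and not reproduced here, and `ToySemi`, `RHSWrapper`, `extract_semi`, `callback_kernel`, and `my_callback` are likewise hypothetical. Extracting `semi` cannot be inferred, but the work behind the barrier is compiled for the concrete type of `semi` once that type is known at runtime.

```julia
# Hypothetical semidiscretization and RHS wrapper (stand-ins, not Trixi.jl types).
struct ToySemi
    speed::Float64
end

struct RHSWrapper{Semi}
    semi::Semi
end

# Models an integrator whose RHS has been wrapped in a way that hides its type.
struct ErasedIntegrator
    f::Any
end

# Type-unstable: the compiler only knows that `integrator.f` is `Any`.
extract_semi(integrator::ErasedIntegrator) = integrator.f.semi

# Function barrier: the actual callback work lives in a separate function,
# which is specialized on the concrete type of `semi` when it is called.
function callback_kernel(semi::ToySemi, u)
    return semi.speed * sum(abs2, u)
end

function my_callback(integrator::ErasedIntegrator, u)
    semi = extract_semi(integrator)   # inferred as `Any`
    return callback_kernel(semi, u)   # dispatch on the runtime type of `semi`
end

integrator = ErasedIntegrator(RHSWrapper(ToySemi(0.5)))
result = my_callback(integrator, [1.0, 2.0])   # 0.5 * (1^2 + 2^2) = 2.5
```

The price is a dynamic dispatch at the barrier on every callback call; whether that matters is a question for benchmarks.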

Member

Have you been able to assess the impact on latency/compilation times with this change yet? Maybe we should try to test this on Roci for a more stable test environment?

If it does improve things (in a relevant way) in terms of TTFX, I think keeping the status quo is not desirable. In that case, I would try to go for the least complex, then the least intrusive solution.

Member Author

Please see the post below: #1255 (comment)

ranocha (Member Author) commented Nov 9, 2022

Initial benchmarks

julia -e '
  @time using OrdinaryDiffEq;
  @time using Trixi;
  @time begin
    equations = LinearScalarAdvectionEquation2D((0.2, -0.7));
    solver = DGSEM(polydeg=3, surface_flux=flux_lax_friedrichs);
    mesh = TreeMesh((-1.0, -1.0), (1.0, 1.0), initial_refinement_level=4, n_cells_max=30_000);
    semi = SemidiscretizationHyperbolic(mesh, equations, initial_condition_convergence_test, solver);
    ode = semidiscretize(semi, (0.0, 1.0))
  end;
  @time integrator = init(ode, SSPRK43());
'

and

julia -e '
  @time using OrdinaryDiffEq;
  @time using Trixi;
  @time begin
    equations = LinearScalarAdvectionEquation2D((0.2, -0.7));
    solver = DGSEM(polydeg=3, surface_flux=flux_lax_friedrichs);
    mesh = TreeMesh((-1.0, -1.0), (1.0, 1.0), initial_refinement_level=4, n_cells_max=30_000);
    semi = SemidiscretizationHyperbolic(mesh, equations, initial_condition_convergence_test, solver);
    ode = semidiscretize(semi, (0.0, 1.0))
  end;
  @time solve(ode, SSPRK43());
'

yield the following preliminary results:

  • init time decreased from ca. 29.0 seconds to ca. 25.3 seconds
  • solve time decreased from ca. 29.4 seconds to ca. 25.9 seconds
  • using and semidiscretize times were unaffected
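The preliminary numbers above translate into relative savings as follows. This is plain arithmetic on the reported timings (in seconds), with `before` being the latest release and `after` this PR; `relative_reduction` is just a helper for this illustration.

```julia
# Fraction of the original time saved when going from `before` to `after`.
relative_reduction(before, after) = (before - after) / before

# (label, before, after) in seconds, copied from the preliminary results above.
timings = [("init", 29.0, 25.3), ("solve", 29.4, 25.9)]

for (label, before, after) in timings
    saved = before - after
    pct = 100 * relative_reduction(before, after)
    println(label, ": ", round(saved; digits = 1), " s saved (",
            round(pct; digits = 1), "%)")
end
```

This gives about 3.7 s (12.8%) for `init` and 3.5 s (11.9%) for `solve` in a fresh session.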

sloede (Member) commented Nov 9, 2022

> yield the following preliminary results:
>
> * `init` time decreased from ca. 29.0 seconds to ca. 25.3 seconds
> * `solve` time decreased from ca. 29.4 seconds to ca. 25.9 seconds
> * `using` and `semidiscretize` times were unaffected

This sounds very promising! However, the loss of type stability in the callbacks could potentially affect runtime performance, couldn't it? Thus, we should probably either run tests for these as well, or make sure we restore type stability through one of the methods you proposed.

ranocha (Member Author) commented Nov 9, 2022

For now, I switched everything to function barriers after extracting the semidiscretization in a type-unstable way. Compared to the latest release of Trixi.jl, I get the following results with the current state of this PR.

  1. After using OrdinaryDiffEq, Trixi, runs of trixi_include("examples/tree_3d_dgsem/elixir_euler_taylor_green_vortex.jl", tspan=(0.0, 0.15)) (after compilation) take more time, as reported by the summary_callback, in this PR:
 ────────────────────────────────────────────────────────────────────────────────────
              Trixi.jl                      Time                    Allocations      
                                   ───────────────────────   ────────────────────────
         Tot / % measured:              1.83s /  43.5%           65.8MiB /  11.6%    

 Section                   ncalls     time    %tot     avg     alloc    %tot      avg
 ────────────────────────────────────────────────────────────────────────────────────
 rhs!                          91    472ms   59.2%  5.19ms   9.33KiB    0.1%     105B
   volume integral             91    284ms   35.6%  3.12ms     0.00B    0.0%    0.00B
   interface flux              91   82.9ms   10.4%   911μs     0.00B    0.0%    0.00B
   prolong2interfaces          91   56.0ms    7.0%   615μs     0.00B    0.0%    0.00B
   surface integral            91   34.9ms    4.4%   383μs     0.00B    0.0%    0.00B
   Jacobian                    91   9.59ms    1.2%   105μs     0.00B    0.0%    0.00B
   reset ∂u/∂t                 91   4.67ms    0.6%  51.3μs     0.00B    0.0%    0.00B
   ~rhs!~                      91    309μs    0.0%  3.40μs   9.33KiB    0.1%     105B
   prolong2boundaries          91   30.1μs    0.0%   330ns     0.00B    0.0%    0.00B
   prolong2mortars             91   20.9μs    0.0%   230ns     0.00B    0.0%    0.00B
   mortar flux                 91   9.39μs    0.0%   103ns     0.00B    0.0%    0.00B
   boundary flux               91   3.21μs    0.0%  35.3ns     0.00B    0.0%    0.00B
   source terms                91   2.16μs    0.0%  23.8ns     0.00B    0.0%    0.00B
 I/O                            3    260ms   32.5%  86.5ms   7.59MiB   99.4%  2.53MiB
   save solution                2    245ms   30.7%   122ms   7.54MiB   98.7%  3.77MiB
   ~I/O~                        3   14.7ms    1.8%  4.91ms   49.9KiB    0.6%  16.6KiB
   get element variables        2   16.6μs    0.0%  8.29μs   4.28KiB    0.1%  2.14KiB
   save mesh                    2    155ns    0.0%  77.5ns     0.00B    0.0%    0.00B
 analyze solution               2   61.4ms    7.7%  30.7ms   34.7KiB    0.4%  17.4KiB
 calculate dt                  19   4.21ms    0.5%   222μs     0.00B    0.0%    0.00B
 ────────────────────────────────────────────────────────────────────────────────────

  1.931855 seconds (1.54 M allocations: 103.443 MiB, 3.52% gc time, 56.16% compilation time: 95% of which was recompilation)

vs.

 ────────────────────────────────────────────────────────────────────────────────────
              Trixi.jl                      Time                    Allocations      
                                   ───────────────────────   ────────────────────────
         Tot / % measured:              831ms /  94.8%           9.08MiB /  84.0%    

 Section                   ncalls     time    %tot     avg     alloc    %tot      avg
 ────────────────────────────────────────────────────────────────────────────────────
 rhs!                          91    459ms   58.2%  5.04ms   9.33KiB    0.1%     105B
   volume integral             91    273ms   34.7%  3.00ms     0.00B    0.0%    0.00B
   interface flux              91   82.1ms   10.4%   902μs     0.00B    0.0%    0.00B
   prolong2interfaces          91   54.5ms    6.9%   599μs     0.00B    0.0%    0.00B
   surface integral            91   34.7ms    4.4%   382μs     0.00B    0.0%    0.00B
   Jacobian                    91   9.24ms    1.2%   101μs     0.00B    0.0%    0.00B
   reset ∂u/∂t                 91   4.64ms    0.6%  51.0μs     0.00B    0.0%    0.00B
   ~rhs!~                      91    312μs    0.0%  3.42μs   9.33KiB    0.1%     105B
   prolong2boundaries          91   26.1μs    0.0%   287ns     0.00B    0.0%    0.00B
   prolong2mortars             91   23.4μs    0.0%   257ns     0.00B    0.0%    0.00B
   mortar flux                 91   14.6μs    0.0%   161ns     0.00B    0.0%    0.00B
   boundary flux               91   2.43μs    0.0%  26.7ns     0.00B    0.0%    0.00B
   source terms                91   2.02μs    0.0%  22.2ns     0.00B    0.0%    0.00B
 I/O                            3    264ms   33.4%  87.8ms   7.59MiB   99.4%  2.53MiB
   save solution                2    247ms   31.4%   124ms   7.54MiB   98.7%  3.77MiB
   ~I/O~                        3   16.4ms    2.1%  5.47ms   49.9KiB    0.6%  16.6KiB
   get element variables        2   18.6μs    0.0%  9.32μs   4.28KiB    0.1%  2.14KiB
   save mesh                    2    220ns    0.0%   110ns     0.00B    0.0%    0.00B
 analyze solution               2   61.2ms    7.8%  30.6ms   34.7KiB    0.4%  17.4KiB
 calculate dt                  19   4.23ms    0.5%   223μs     0.00B    0.0%    0.00B
 ────────────────────────────────────────────────────────────────────────────────────

  2.274410 seconds (2.19 M allocations: 138.377 MiB, 63.22% compilation time: 98% of which was recompilation)

in the latest release (with Julia 1.8.2).

  2. However, the timings of the individual parts do not differ significantly (in particular rhs!, analyze solution, and calculate dt).
  3. The times of @time solve(ode, SSPRK43(), callback=callbacks, save_everystep=false); summary_callback() after running the trixi_include above and after compilation do not change significantly (this PR first, then the latest release):
  0.716991 seconds (3.14 k allocations: 25.265 MiB)
 ────────────────────────────────────────────────────────────────────────────────────
              Trixi.jl                      Time                    Allocations      
                                   ───────────────────────   ────────────────────────
         Tot / % measured:              702ms /  95.7%           13.9MiB /  54.7%    

 Section                   ncalls     time    %tot     avg     alloc    %tot      avg
 ────────────────────────────────────────────────────────────────────────────────────
 rhs!                          70    359ms   53.4%  5.13ms   9.33KiB    0.1%     136B
   volume integral             70    217ms   32.3%  3.10ms     0.00B    0.0%    0.00B
   interface flux              70   62.5ms    9.3%   893μs     0.00B    0.0%    0.00B
   prolong2interfaces          70   42.6ms    6.3%   609μs     0.00B    0.0%    0.00B
   surface integral            70   24.9ms    3.7%   356μs     0.00B    0.0%    0.00B
   Jacobian                    70   7.16ms    1.1%   102μs     0.00B    0.0%    0.00B
   reset ∂u/∂t                 70   4.54ms    0.7%  64.8μs     0.00B    0.0%    0.00B
   ~rhs!~                      70    232μs    0.0%  3.32μs   9.33KiB    0.1%     136B
   prolong2mortars             70   15.7μs    0.0%   224ns     0.00B    0.0%    0.00B
   prolong2boundaries          70   15.2μs    0.0%   217ns     0.00B    0.0%    0.00B
   mortar flux                 70   9.47μs    0.0%   135ns     0.00B    0.0%    0.00B
   source terms                70   4.12μs    0.0%  58.8ns     0.00B    0.0%    0.00B
   boundary flux               70   2.57μs    0.0%  36.7ns     0.00B    0.0%    0.00B
 I/O                            3    249ms   37.0%  82.9ms   7.54MiB   99.4%  2.51MiB
   save solution                2    249ms   37.0%   124ms   7.53MiB   99.3%  3.77MiB
   get element variables        2   37.3μs    0.0%  18.6μs   4.28KiB    0.1%  2.14KiB
   ~I/O~                        3   15.2μs    0.0%  5.07μs   3.20KiB    0.0%  1.07KiB
   save mesh                    2    271ns    0.0%   136ns     0.00B    0.0%    0.00B
 analyze solution               2   64.5ms    9.6%  32.2ms   34.7KiB    0.4%  17.4KiB
 ────────────────────────────────────────────────────────────────────────────────────

vs.

  0.731163 seconds (3.08 k allocations: 25.247 MiB)
 ────────────────────────────────────────────────────────────────────────────────────
              Trixi.jl                      Time                    Allocations      
                                   ───────────────────────   ────────────────────────
         Tot / % measured:              716ms /  95.7%           13.8MiB /  54.8%    

 Section                   ncalls     time    %tot     avg     alloc    %tot      avg
 ────────────────────────────────────────────────────────────────────────────────────
 rhs!                          70    372ms   54.2%  5.31ms   9.33KiB    0.1%     136B
   volume integral             70    221ms   32.2%  3.16ms     0.00B    0.0%    0.00B
   interface flux              70   65.9ms    9.6%   942μs     0.00B    0.0%    0.00B
   prolong2interfaces          70   45.5ms    6.6%   649μs     0.00B    0.0%    0.00B
   surface integral            70   26.9ms    3.9%   385μs     0.00B    0.0%    0.00B
   Jacobian                    70   7.38ms    1.1%   105μs     0.00B    0.0%    0.00B
   reset ∂u/∂t                 70   4.94ms    0.7%  70.5μs     0.00B    0.0%    0.00B
   ~rhs!~                      70    276μs    0.0%  3.94μs   9.33KiB    0.1%     136B
   prolong2boundaries          70   20.8μs    0.0%   297ns     0.00B    0.0%    0.00B
   prolong2mortars             70   18.0μs    0.0%   258ns     0.00B    0.0%    0.00B
   mortar flux                 70   8.81μs    0.0%   126ns     0.00B    0.0%    0.00B
   boundary flux               70   1.90μs    0.0%  27.1ns     0.00B    0.0%    0.00B
   source terms                70   1.73μs    0.0%  24.7ns     0.00B    0.0%    0.00B
 I/O                            3    249ms   36.3%  82.9ms   7.54MiB   99.4%  2.51MiB
   save solution                2    249ms   36.3%   124ms   7.53MiB   99.3%  3.77MiB
   get element variables        2   14.2μs    0.0%  7.11μs   4.28KiB    0.1%  2.14KiB
   ~I/O~                        3   8.30μs    0.0%  2.77μs   3.20KiB    0.0%  1.07KiB
   save mesh                    2    177ns    0.0%  88.5ns     0.00B    0.0%    0.00B
 analyze solution               2   65.2ms    9.5%  32.6ms   34.7KiB    0.4%  17.4KiB
 ────────────────────────────────────────────────────────────────────────────────────

in the latest release.

ranocha requested a review from sloede on November 9, 2022, 15:52
ranocha (Member Author) commented Nov 9, 2022

@sloede I think this is ready for a first review. It would be great to get some help running tests and benchmarks in the wild with this.

ranocha (Member Author) commented Nov 10, 2022

There is a huge problem with plotting. If we use this, the ODE problem (and solution) will not carry any type information from Trixi.jl anymore. That means dispatching on the ODE solution type for our plotting will be type piracy and overwrite existing plotting methods. I guess this basically settles this - we cannot use this approach 😞
Do you have any other idea, @sloede?
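A toy illustration of the plotting problem, assuming the plotting methods are selected by dispatch on the solution type as described above. `Upstream` stands in for the SciML packages and `MyTrixi` for Trixi.jl; both modules and everything in them are hypothetical. Adding a method to a function you do not own avoids type piracy only if you own one of the types in its signature, which is why the wrapped solution can no longer be targeted.

```julia
module Upstream
# Type-erased RHS wrapper, standing in for the wrapping done by the ODE solver.
struct Erased
    f::Any
end

struct Solution{F, P}
    f::F
    p::P
end

describe(sol::Solution) = "generic solution"
end # module Upstream

module MyTrixi
import ..Upstream

struct Semi end

# Not type piracy: MyTrixi owns `Semi`, which appears in the signature.
Upstream.describe(sol::Upstream.Solution{<:Any, <:Semi}) = "Trixi.jl solution"
end # module MyTrixi

rhs!(du, u, p, t) = nothing

# Status quo: the semidiscretization is the ODE parameter `p`.
sol_status_quo = Upstream.Solution(rhs!, MyTrixi.Semi())

# With type-erased wrapping: neither `F` nor `P` is a MyTrixi type, so
# specializing `describe` on this solution type would be type piracy.
sol_wrapped = Upstream.Solution(Upstream.Erased(rhs!), nothing)

Upstream.describe(sol_status_quo)   # "Trixi.jl solution"
Upstream.describe(sol_wrapped)      # "generic solution"
```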

sloede (Member) commented Nov 10, 2022

> There is a huge problem with plotting. If we use this, the ODE problem (and solution) will not carry any type information from Trixi.jl anymore. That means dispatching on the ODE solution type for our plotting will be type piracy and overwrite existing plotting methods. I guess this basically settles this - we cannot use this approach 😞 Do you have any other idea, @sloede?

Just so that I understand it correctly: With the change proposed here, plot(sol) will not work anymore (but plot(PlotData2D(sol)) will)?

ranocha (Member Author) commented Nov 10, 2022

Basically yes - we could make this work by breaking plot(sol), which I really dislike 😞

sloede (Member) commented Nov 10, 2022

> Basically yes - we could make this work by breaking plot(sol), which I really dislike 😞

Hm. I see three options:

  • Break plot(sol)
  • Introduce a new function (tplot? trixi_plot?) to allow quick'n'dirty plotting of sol, e.g., tplot(sol), but it will not be the same as before
  • Suffer the latency penalty like before

Maybe we should collect the changes in the runtime numbers and put this up for discussion at a meeting? I also tend to dislike breaking plot(sol) the most, simply because it is a vital part of the "easy to use" appeal...

ranocha (Member Author) commented Nov 11, 2022

> Maybe we should collect the changes in the runtime numbers and put this up for discussion at a meeting?

Sounds reasonable. Any help is welcome 🙂

sloede (Member) commented Dec 2, 2022

Would it possibly help if we, instead of storing the full semi in the parameters of sol, only store a sentinel (singleton?) type? That is, just enough to be able to continue to use plot(sol), but not the full semi anymore, such that it can be efficiently precompiled (as it does not change with each configuration anymore). Or would this just cause other problems?
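A rough sketch of the sentinel idea, under the assumption that the full semidiscretization is carried by the (type-erased) RHS wrapper while `p` holds only a singleton tag. Because the tag's type does not depend on the simulation setup, the parameter type stays the same across configurations, yet dispatch on the solution type can still reach Trixi.jl-owned methods without piracy. All modules and names (`Upstream`, `MyTrixi`, `Semi`, `RHSWrapper`, `TrixiTag`, `get_semi`) are hypothetical.

```julia
module Upstream
# Type-erased RHS wrapper, standing in for the wrapping done by the ODE solver.
struct Erased
    f::Any
end

struct Solution{F, P}
    f::F
    p::P
end

describe(sol::Solution) = "generic solution"
end # module Upstream

module MyTrixi
import ..Upstream

# Stands in for the full semidiscretization, whose concrete type would change
# with every simulation setup.
struct Semi
    nelements::Int
end

# RHS wrapper carrying the semidiscretization.
struct RHSWrapper{S}
    semi::S
end
(w::RHSWrapper)(du, u, p, t) = nothing   # placeholder RHS

# Singleton sentinel: the same concrete type for every simulation setup.
struct TrixiTag end

# Not type piracy: MyTrixi owns `TrixiTag`.
Upstream.describe(sol::Upstream.Solution{<:Any, TrixiTag}) = "Trixi.jl solution"

# Recovering the semidiscretization goes through the type-erased field,
# so it is not type stable (a function barrier would follow in practice).
get_semi(sol::Upstream.Solution{<:Any, TrixiTag}) = sol.f.f.semi
end # module MyTrixi

semi = MyTrixi.Semi(256)
sol = Upstream.Solution(Upstream.Erased(MyTrixi.RHSWrapper(semi)), MyTrixi.TrixiTag())

Upstream.describe(sol)    # "Trixi.jl solution"
MyTrixi.get_semi(sol)     # returns `semi`
```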

ranocha (Member Author) commented Dec 2, 2022

> Would it possibly help if we, instead of storing the full semi in the parameters of sol, only store a sentinel (singleton?) type? That is, just enough to be able to continue to use plot(sol), but not the full semi anymore, such that it can be efficiently precompiled (as it does not change with each configuration anymore). Or would this just cause other problems?

This could indeed work but would also require that we perform the precompilation of ODE solvers from OrdinaryDiffEq.jl within Trixi.jl. In particular, we can't benefit from their precompilation directly.

ranocha mentioned this pull request on Dec 2, 2022
ranocha mentioned this pull request on Dec 8, 2022
ranocha mentioned this pull request on Dec 24, 2022
sloede (Member) commented Jan 23, 2023

> > Would it possibly help if we, instead of storing the full semi in the parameters of sol, only store a sentinel (singleton?) type? That is, just enough to be able to continue to use plot(sol), but not the full semi anymore, such that it can be efficiently precompiled (as it does not change with each configuration anymore). Or would this just cause other problems?
>
> This could indeed work but would also require that we perform the precompilation of ODE solvers from OrdinaryDiffEq.jl within Trixi.jl. In particular, we can't benefit from their precompilation directly.

Could this be fixed with upstream support? E.g., if they added a type parameter in an appropriate location such that it will not break precompilation benefits but allow us to "tag" the return value somehow? I'm just brainstorming; probably what I'm trying to do is fundamentally impossible with the type system 😅

ranocha (Member Author) commented Jan 23, 2023

Yeah, I guess it's impossible. They would need to precompile for something that is created by Trixi.jl - which is impossible since they don't depend on us.

ranocha mentioned this pull request on Nov 10, 2023
ranocha mentioned this pull request on Feb 23, 2024
thomvet commented Mar 29, 2024

I am not sure, but is this something where the relatively new https://github.com/SciML/SciMLStructures.jl could help in the long run?

ranocha (Member Author) commented Mar 29, 2024

> I am not sure, but is this something where the relatively new https://github.com/SciML/SciMLStructures.jl could help in the long run?

I'm not sure.

ranocha mentioned this pull request on Jul 1, 2024
ranocha mentioned this pull request on Oct 10, 2024
Successfully merging this pull request may close these issues: Wrap semidiscretization in RHS passed to OrdinaryDiffEq.jl
3 participants