-
-
Notifications
You must be signed in to change notification settings - Fork 71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mooncake Inside Problems #1152
Merged
ChrisRackauckas
merged 18 commits into
SciML:master
from
willtebbutt:wct/mooncake-inside-problems
Dec 5, 2024
Merged
Mooncake Inside Problems #1152
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
0e5a563
Initial pass at adjoints
willtebbutt 67044c1
Reduce duplication
willtebbutt d34fa69
Remove commented out code
willtebbutt a6d3f6f
Tidy up further
willtebbutt 0eed9ca
Tidy up furhter
willtebbutt 3c7bf42
More formatting
willtebbutt 59d8a22
Tidy up further
willtebbutt f5b5894
Fix formatting
willtebbutt 2b2ac55
Remove more redundancy
willtebbutt 48417e2
Remove code duplication further
willtebbutt 0733b50
Update docs
willtebbutt eaaba53
Make use of new Mooncake feature
willtebbutt 9935a2e
OOP tests
willtebbutt 9f54f79
Turn Mooncake into an extension
willtebbutt b9d0e1e
Tidy up Project toml
willtebbutt cf27ae4
Actually commit the extension
willtebbutt 3cf90d4
Bump patch version
willtebbutt 6ac50d1
Formatting
willtebbutt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "SciMLSensitivity" | ||
uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1" | ||
authors = ["Christopher Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"] | ||
version = "7.71.2" | ||
version = "7.72.0" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
|
@@ -42,6 +42,12 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" | |
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[weakdeps] | ||
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" | ||
|
||
[extensions] | ||
SciMLSensitivityMooncakeExt = "Mooncake" | ||
|
||
[compat] | ||
ADTypes = "1.9" | ||
Accessors = "0.1.36" | ||
|
@@ -71,6 +77,7 @@ LinearSolve = "2" | |
Lux = "1" | ||
Markdown = "1.10" | ||
ModelingToolkit = "9.42" | ||
Mooncake = "0.4.52" | ||
NLsolve = "4.5.1" | ||
NonlinearSolve = "3.0.1" | ||
Optimization = "4" | ||
|
@@ -110,6 +117,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" | |
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" | ||
Lux = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" | ||
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" | ||
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" | ||
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" | ||
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" | ||
|
@@ -123,4 +131,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" | |
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[targets] | ||
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] | ||
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
module SciMLSensitivityMooncakeExt | ||
|
||
using SciMLSensitivity, Mooncake | ||
import SciMLSensitivity: get_paramjac_config, mooncake_run_ad, MooncakeVJP, MooncakeLoaded | ||
|
||
function get_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, pf, p, f, y, _t) | ||
dy_mem = zero(y) | ||
λ_mem = zero(y) | ||
cache = Mooncake.prepare_pullback_cache(pf, dy_mem, y, p, _t) | ||
return cache, pf, λ_mem, dy_mem | ||
end | ||
|
||
function mooncake_run_ad(paramjac_config::Tuple, y, p, t, λ) | ||
cache, pf, λ_mem, dy_mem = paramjac_config | ||
λ_mem .= λ | ||
dy, _ = Mooncake.value_and_pullback!!(cache, λ_mem, pf, dy_mem, y, p, t) | ||
y_grad = cache.tangents[3] | ||
p_grad = cache.tangents[4] | ||
return dy, y_grad, p_grad | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
using SciMLSensitivity, SafeTestsets | ||
using Test, Pkg | ||
import Mooncake | ||
|
||
const GROUP = get(ENV, "GROUP", "All") | ||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ChrisRackauckas, it is probably better to refactor these hard-coded branches (e.g., define an interface function that other packages can overload). It would help
It might also help to switch to DI where possible to avoid duplicate glue code in the ecosystem. @gdalle
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is that really a high priority right now? How many more autograd packages are you going to write this year that will be useful?
Doesn't doesn't necessarily make sense. Most of the methods are used in the default method so they would be required to be loaded by default anyways?
That's the plan when it's able to handle this case well. Currently it's not able to.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mooncake is getting a new forward mode (an attempt to improve ForwardDiff with GPU compatibility and fewer constraints; see here for more details), so @willtebbutt will likely need to modify these again in the near term.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't that just require modifying https://github.com/SciML/SciMLSensitivity.jl/pull/1152/files#diff-1a15b4b5711133c125548ef7f1ca88f761bb124cffc8bfde8c13336968aaccd6R466 ? I don't see why that would touch this function and instead just dispatch on there.
I mean, if someone wants to do a refactor here that's perfectly fine. But I also don't see why it would be a high priority since it's not like new AD systems get added every year, and modifications to existing ones don't really touch this part of the code much. I would think the time would be better spent just trying to get DI up to speed than refactoring this old code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can discuss DI integration in #1040 if you want