From dd5c227fb8c16d7016005b55bff75a1733adf442 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 18 May 2024 11:27:39 -0700 Subject: [PATCH 1/5] WIP transition to package extensions --- Manifest.toml | 390 ++++++++++++++++++++++++++++++++++++ Project.toml | 13 +- ext/ADNLPModelsEnzymeExt.jl | 23 +++ ext/ADNLPModelsZygoteExt.jl | 121 +++++++++++ src/ADNLPModels.jl | 21 +- src/enzyme.jl | 21 -- src/zygote.jl | 119 ----------- 7 files changed, 564 insertions(+), 144 deletions(-) create mode 100644 Manifest.toml create mode 100644 ext/ADNLPModelsEnzymeExt.jl create mode 100644 ext/ADNLPModelsZygoteExt.jl delete mode 100644 src/enzyme.jl delete mode 100644 src/zygote.jl diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 00000000..d8f41f14 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,390 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.3" +manifest_format = "2.0" +project_hash = "bb1d07f7da969f7a910a69afe0616febf618ea69" + +[[deps.AMD]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse_jll"] +git-tree-sha1 = "45a1272e3f809d36431e57ab22703c6896b8908f" +uuid = "14f7f29c-3bd6-536c-9a0b-7339e30b5a3e" +version = "0.5.3" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.23.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + +[[deps.ColPack]] +deps = ["ColPack_jll", "LinearAlgebra", "Random", "SparseArrays"] +git-tree-sha1 = "83a23545e7969d8b21fb85271b9cd04c8e09d08b" +uuid = "ffa27691-3a59-46ab-a8d4-551f45b8d401" +version = "0.3.0" + +[[deps.ColPack_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "2a518018a2b888ba529e944d34d4bd84b54d652d" +uuid = "f218ff0c-cb54-5151-80c4-c0f62c730ce6" +version = "0.3.0+0" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.15.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.FastClosures]] +git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef" +uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +version = "0.3.2" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + +[[deps.FunctionWrappers]] +git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.3" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.LDLFactorizations]] +deps = ["AMD", "LinearAlgebra", "SparseArrays", "Test"] +git-tree-sha1 = "70f582b446a1c3ad82cf87e62b878668beef9d13" +uuid = "40e66cde-538c-5869-a4ad-c39174c6795b" +version = "0.10.1" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LinearOperators]] +deps = ["FastClosures", "LDLFactorizations", "LinearAlgebra", "Printf", "Requires", "SparseArrays", "TimerOutputs"] +git-tree-sha1 = "f06df3a46255879cbccae1b5b6dcb16994c31be7" +uuid = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" +version = "2.7.0" +weakdeps = ["ChainRulesCore"] + + [deps.LinearOperators.extensions] + LinearOperatorsChainRulesCoreExt = "ChainRulesCore" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.27" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NLPModels]] +deps = ["FastClosures", "LinearAlgebra", "LinearOperators", "Printf", "SparseArrays"] +git-tree-sha1 = "2d110433ba53dcf225c1a958b1f91fda6cf3cead" +uuid = "a4795742-8479-5a88-8948-cc11e1c8c1a6" +version = "0.21.0" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.ReverseDiff]] +deps = ["ChainRulesCore", "DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"] +git-tree-sha1 = "cc6cd622481ea366bb9067859446a8b01d92b468" +uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +version = "1.15.3" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.4.0" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.3" +weakdeps = ["ChainRulesCore", "Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.2" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.23" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index f7b60817..937b541a 100644 --- a/Project.toml +++ b/Project.toml @@ -13,11 +13,22 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +[weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[extensions] +ADNLPModelsEnzymeExt = "Enzyme" +ADNLPModelsSparseDiffToolsExt = "SparseDiffTools" +ADNLPModelsSymbolicsExt = "Symbolics" +ADNLPModelsZygoteExt = "Zygote" + [compat] ADTypes = "1.2.1" ForwardDiff = "0.9.0, 0.10.0" NLPModels = "0.18, 0.19, 0.20, 0.21" -Requires = "1" ReverseDiff = "1" SparseConnectivityTracer = "0.6.1" SparseMatrixColorings = "0.4.0" diff --git a/ext/ADNLPModelsEnzymeExt.jl b/ext/ADNLPModelsEnzymeExt.jl new file mode 100644 index 00000000..13ed9015 --- /dev/null +++ b/ext/ADNLPModelsEnzymeExt.jl @@ -0,0 +1,23 @@ +module ADNLPModelsEnzymeExt + +using Enzyme, ADNLPModels + +struct EnzymeADGradient <: ADNLPModels.ADBackend end + +function EnzymeADGradient( + nvar::Integer, + f, + ncon::Integer = 0, + c::Function = (args...) -> []; + x0::AbstractVector = rand(nvar), + kwargs..., +) + return EnzymeADGradient() +end + +function ADNLPModels.gradient!(::EnzymeADGradient, g, f, x) + Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x) + return g +end + +end \ No newline at end of file diff --git a/ext/ADNLPModelsZygoteExt.jl b/ext/ADNLPModelsZygoteExt.jl new file mode 100644 index 00000000..9bb8707f --- /dev/null +++ b/ext/ADNLPModelsZygoteExt.jl @@ -0,0 +1,121 @@ +module ADNLPModelsZygoteExt + +using Zygote, ADNLPModels +import ADNLPModels: ADModel, AbstractADNLSModel, ADBackend, ImmutableADbackend + +struct ZygoteADGradient <: ADBackend end +struct ZygoteADJacobian <: ImmutableADbackend + nnzj::Int +end +struct ZygoteADHessian <: ImmutableADbackend + nnzh::Int +end +struct ZygoteADJprod <: ImmutableADbackend end +struct ZygoteADJtprod <: ImmutableADbackend end +# See https://fluxml.ai/Zygote.jl/latest/limitations/ +function get_immutable_c(nlp::ADModel) + function c(x; nnln = nlp.meta.nnln) + c = Zygote.Buffer(x, nnln) + nlp.c!(c, x) + return copy(c) + end + return c +end +get_c(nlp::ADModel, ::ImmutableADbackend) = get_immutable_c(nlp) + +function get_immutable_F(nls::AbstractADNLSModel) + function F(x; nequ = nls.nls_meta.nequ) + Fx = Zygote.Buffer(x, nequ) + nls.F!(Fx, x) + return copy(Fx) + end + return F +end +get_F(nls::AbstractADNLSModel, ::ImmutableADbackend) = get_immutable_F(nls) + +function ZygoteADGradient( + nvar::Integer, + f, + ncon::Integer = 0, + c::Function = (args...) -> []; + kwargs..., +) + return ZygoteADGradient() +end +function gradient(::ZygoteADGradient, f, x) + g = Zygote.gradient(f, x)[1] + return g === nothing ? zero(x) : g +end +function gradient!(::ZygoteADGradient, g, f, x) + _g = Zygote.gradient(f, x)[1] + g .= _g === nothing ? 0 : _g +end + +function ZygoteADJacobian( + nvar::Integer, + f, + ncon::Integer = 0, + c::Function = (args...) -> []; + kwargs..., +) + @assert nvar > 0 + nnzj = nvar * ncon + return ZygoteADJacobian(nnzj) +end +function jacobian(::ZygoteADJacobian, f, x) + return Zygote.jacobian(f, x)[1] +end + +function ZygoteADHessian( + nvar::Integer, + f, + ncon::Integer = 0, + c::Function = (args...) -> []; + kwargs..., +) + @assert nvar > 0 + nnzh = nvar * (nvar + 1) / 2 + return ZygoteADHessian(nnzh) +end +function hessian(b::ZygoteADHessian, f, x) + return jacobian( + ForwardDiffADJacobian(length(x), f, x0 = x), + x -> gradient(ZygoteADGradient(), f, x), + x, + ) +end + +function ZygoteADJprod( + nvar::Integer, + f, + ncon::Integer = 0, + c::Function = (args...) -> []; + kwargs..., +) + return ZygoteADJprod() +end +function Jprod!(::ZygoteADJprod, Jv, f, x, v, ::Val) + Jv .= vec(Zygote.jacobian(t -> f(x + t * v), 0)[1]) + return Jv +end + +function ZygoteADJtprod( + nvar::Integer, + f, + ncon::Integer = 0, + c::Function = (args...) -> []; + kwargs..., +) + return ZygoteADJtprod() +end +function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val) + g = Zygote.gradient(x -> dot(f(x), v), x)[1] + if g === nothing + Jtv .= zero(x) + else + Jtv .= g + end + return Jtv +end + +end \ No newline at end of file diff --git a/src/ADNLPModels.jl b/src/ADNLPModels.jl index a50d1005..3fd6effd 100644 --- a/src/ADNLPModels.jl +++ b/src/ADNLPModels.jl @@ -11,7 +11,6 @@ using ForwardDiff, ReverseDiff # JSO using NLPModels -using Requires abstract type AbstractADNLPModel{T, S} <: AbstractNLPModel{T, S} end abstract type AbstractADNLSModel{T, S} <: AbstractNLSModel{T, S} end @@ -25,10 +24,26 @@ include("sparsity_pattern.jl") include("sparse_jacobian.jl") include("sparse_hessian.jl") +# Attempt to load some symbols from the Symbolics extensions +symbolics_ext = Base.get_extension(@__MODULE__, :ADNLPModelsSymbolicsExt) +if !isnothing(symbolics_ext) + SparseSymbolicsADJacobian = symbolics_ext.SparseSymbolicsADJacobian + SparseSymbolicsADHessian = symbolics_ext.SparseSymbolicsADHessian + SDTSparseADJacobian = symbolics_ext.SDTSparseADJacobian + + # These backends should only be included if the module is loaded + predefined_backend[:default][:jacobian_backend] = SparseADJacobian + predefined_backend[:default][:jacobian_residual_backend] = SparseADJacobian + predefined_backend[:optimized][:jacobian_backend] = SparseADJacobian + predefined_backend[:optimized][:jacobian_residual_backend] = SparseADJacobian + + predefined_backend[:default][:hessian_backend] = SparseADHessian + predefined_backend[:optimized][:hessian_backend] = SparseReverseADHessian +end + include("forward.jl") include("reverse.jl") -include("enzyme.jl") -include("zygote.jl") + include("predefined_backend.jl") include("nlp.jl") diff --git a/src/enzyme.jl b/src/enzyme.jl deleted file mode 100644 index db5133fe..00000000 --- a/src/enzyme.jl +++ /dev/null @@ -1,21 +0,0 @@ -struct EnzymeADGradient <: ADNLPModels.ADBackend end - -function EnzymeADGradient( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - x0::AbstractVector = rand(nvar), - kwargs..., -) - return EnzymeADGradient() -end - -@init begin - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin - function ADNLPModels.gradient!(::EnzymeADGradient, g, f, x) - Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x) - return g - end - end -end diff --git a/src/zygote.jl b/src/zygote.jl deleted file mode 100644 index 63358a7e..00000000 --- a/src/zygote.jl +++ /dev/null @@ -1,119 +0,0 @@ -struct ZygoteADGradient <: ADBackend end -struct ZygoteADJacobian <: ImmutableADbackend - nnzj::Int -end -struct ZygoteADHessian <: ImmutableADbackend - nnzh::Int -end -struct ZygoteADJprod <: ImmutableADbackend end -struct ZygoteADJtprod <: ImmutableADbackend end - -@init begin - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - # See https://fluxml.ai/Zygote.jl/latest/limitations/ - function get_immutable_c(nlp::ADModel) - function c(x; nnln = nlp.meta.nnln) - c = Zygote.Buffer(x, nnln) - nlp.c!(c, x) - return copy(c) - end - return c - end - get_c(nlp::ADModel, ::ImmutableADbackend) = get_immutable_c(nlp) - - function get_immutable_F(nls::AbstractADNLSModel) - function F(x; nequ = nls.nls_meta.nequ) - Fx = Zygote.Buffer(x, nequ) - nls.F!(Fx, x) - return copy(Fx) - end - return F - end - get_F(nls::AbstractADNLSModel, ::ImmutableADbackend) = get_immutable_F(nls) - - function ZygoteADGradient( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - kwargs..., - ) - return ZygoteADGradient() - end - function gradient(::ZygoteADGradient, f, x) - g = Zygote.gradient(f, x)[1] - return g === nothing ? zero(x) : g - end - function gradient!(::ZygoteADGradient, g, f, x) - _g = Zygote.gradient(f, x)[1] - g .= _g === nothing ? 0 : _g - end - - function ZygoteADJacobian( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - kwargs..., - ) - @assert nvar > 0 - nnzj = nvar * ncon - return ZygoteADJacobian(nnzj) - end - function jacobian(::ZygoteADJacobian, f, x) - return Zygote.jacobian(f, x)[1] - end - - function ZygoteADHessian( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - kwargs..., - ) - @assert nvar > 0 - nnzh = nvar * (nvar + 1) / 2 - return ZygoteADHessian(nnzh) - end - function hessian(b::ZygoteADHessian, f, x) - return jacobian( - ForwardDiffADJacobian(length(x), f, x0 = x), - x -> gradient(ZygoteADGradient(), f, x), - x, - ) - end - - function ZygoteADJprod( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - kwargs..., - ) - return ZygoteADJprod() - end - function Jprod!(::ZygoteADJprod, Jv, f, x, v, ::Val) - Jv .= vec(Zygote.jacobian(t -> f(x + t * v), 0)[1]) - return Jv - end - - function ZygoteADJtprod( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - kwargs..., - ) - return ZygoteADJtprod() - end - function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val) - g = Zygote.gradient(x -> dot(f(x), v), x)[1] - if g === nothing - Jtv .= zero(x) - else - Jtv .= g - end - return Jtv - end - end -end From 8f237c089e3e69d694f58c6dca8f61709d6532c2 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 18 May 2024 17:53:09 -0700 Subject: [PATCH 2/5] Only use package extensions for AD backends --- Project.toml | 2 -- src/ADNLPModels.jl | 21 +++------------------ 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index 937b541a..e475e309 100644 --- a/Project.toml +++ b/Project.toml @@ -21,8 +21,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] ADNLPModelsEnzymeExt = "Enzyme" -ADNLPModelsSparseDiffToolsExt = "SparseDiffTools" -ADNLPModelsSymbolicsExt = "Symbolics" ADNLPModelsZygoteExt = "Zygote" [compat] diff --git a/src/ADNLPModels.jl b/src/ADNLPModels.jl index 3fd6effd..a50d1005 100644 --- a/src/ADNLPModels.jl +++ b/src/ADNLPModels.jl @@ -11,6 +11,7 @@ using ForwardDiff, ReverseDiff # JSO using NLPModels +using Requires abstract type AbstractADNLPModel{T, S} <: AbstractNLPModel{T, S} end abstract type AbstractADNLSModel{T, S} <: AbstractNLSModel{T, S} end @@ -24,26 +25,10 @@ include("sparsity_pattern.jl") include("sparse_jacobian.jl") include("sparse_hessian.jl") -# Attempt to load some symbols from the Symbolics extensions -symbolics_ext = Base.get_extension(@__MODULE__, :ADNLPModelsSymbolicsExt) -if !isnothing(symbolics_ext) - SparseSymbolicsADJacobian = symbolics_ext.SparseSymbolicsADJacobian - SparseSymbolicsADHessian = symbolics_ext.SparseSymbolicsADHessian - SDTSparseADJacobian = symbolics_ext.SDTSparseADJacobian - - # These backends should only be included if the module is loaded - predefined_backend[:default][:jacobian_backend] = SparseADJacobian - predefined_backend[:default][:jacobian_residual_backend] = SparseADJacobian - predefined_backend[:optimized][:jacobian_backend] = SparseADJacobian - predefined_backend[:optimized][:jacobian_residual_backend] = SparseADJacobian - - predefined_backend[:default][:hessian_backend] = SparseADHessian - predefined_backend[:optimized][:hessian_backend] = SparseReverseADHessian -end - include("forward.jl") include("reverse.jl") - +include("enzyme.jl") +include("zygote.jl") include("predefined_backend.jl") include("nlp.jl") From 37e9cc11d16dbc4200d2c8ceb03989472f6644c8 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Tue, 26 Nov 2024 14:16:32 -0600 Subject: [PATCH 3/5] Transition to package extensions --- Manifest.toml | 390 -------------------------------------------------- Project.toml | 12 +- 2 files changed, 7 insertions(+), 395 deletions(-) delete mode 100644 Manifest.toml diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index d8f41f14..00000000 --- a/Manifest.toml +++ /dev/null @@ -1,390 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.3" -manifest_format = "2.0" -project_hash = "bb1d07f7da969f7a910a69afe0616febf618ea69" - -[[deps.AMD]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse_jll"] -git-tree-sha1 = "45a1272e3f809d36431e57ab22703c6896b8908f" -uuid = "14f7f29c-3bd6-536c-9a0b-7339e30b5a3e" -version = "0.5.3" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.23.0" -weakdeps = ["SparseArrays"] - - [deps.ChainRulesCore.extensions] - ChainRulesCoreSparseArraysExt = "SparseArrays" - -[[deps.ColPack]] -deps = ["ColPack_jll", "LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "83a23545e7969d8b21fb85271b9cd04c8e09d08b" -uuid = "ffa27691-3a59-46ab-a8d4-551f45b8d401" -version = "0.3.0" - -[[deps.ColPack_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "2a518018a2b888ba529e944d34d4bd84b54d652d" -uuid = "f218ff0c-cb54-5151-80c4-c0f62c730ce6" -version = "0.3.0+0" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.15.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.ExprTools]] -git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" -uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.10" - -[[deps.FastClosures]] -git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef" -uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" -version = "0.3.2" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.FunctionWrappers]] -git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" -uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" -version = "1.1.3" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.LDLFactorizations]] -deps = ["AMD", "LinearAlgebra", "SparseArrays", "Test"] -git-tree-sha1 = "70f582b446a1c3ad82cf87e62b878668beef9d13" -uuid = "40e66cde-538c-5869-a4ad-c39174c6795b" -version = "0.10.1" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LinearOperators]] -deps = ["FastClosures", "LDLFactorizations", "LinearAlgebra", "Printf", "Requires", "SparseArrays", "TimerOutputs"] -git-tree-sha1 = "f06df3a46255879cbccae1b5b6dcb16994c31be7" -uuid = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" -version = "2.7.0" -weakdeps = ["ChainRulesCore"] - - [deps.LinearOperators.extensions] - LinearOperatorsChainRulesCoreExt = "ChainRulesCore" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.27" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.13" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.NLPModels]] -deps = ["FastClosures", "LinearAlgebra", "LinearOperators", "Printf", "SparseArrays"] -git-tree-sha1 = "2d110433ba53dcf225c1a958b1f91fda6cf3cead" -uuid = "a4795742-8479-5a88-8948-cc11e1c8c1a6" -version = "0.21.0" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.ReverseDiff]] -deps = ["ChainRulesCore", "DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"] -git-tree-sha1 = "cc6cd622481ea366bb9067859446a8b01d92b468" -uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -version = "1.15.3" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.4.0" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.3" -weakdeps = ["ChainRulesCore", "Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TimerOutputs]] -deps = ["ExprTools", "Printf"] -git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" -uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.23" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index e475e309..d8fd8dff 100644 --- a/Project.toml +++ b/Project.toml @@ -15,8 +15,6 @@ SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" [weakdeps] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" -Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -25,9 +23,13 @@ ADNLPModelsZygoteExt = "Zygote" [compat] ADTypes = "1.2.1" -ForwardDiff = "0.9.0, 0.10.0" -NLPModels = "0.18, 0.19, 0.20, 0.21" +ForwardDiff = "0.10.0" +NLPModels = "0.21.3" ReverseDiff = "1" SparseConnectivityTracer = "0.6.1" SparseMatrixColorings = "0.4.0" -julia = "^1.6" +Enzyme = "0.13" +Zygote = "0.6" +LinearAlgebra = "1.10" +SparseArrays = "1.10" +julia = "1.10" From d1dd9a0739f1e3d99c6e836424001ec592a1033a Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Wed, 27 Nov 2024 10:41:21 -0600 Subject: [PATCH 4/5] Update the extensions --- ext/ADNLPModelsEnzymeExt.jl | 17 +----- ext/ADNLPModelsZygoteExt.jl | 106 +++++------------------------------- src/ADNLPModels.jl | 6 +- src/enzyme.jl | 12 ++++ src/zygote.jl | 83 ++++++++++++++++++++++++++++ 5 files changed, 114 insertions(+), 110 deletions(-) create mode 100644 src/enzyme.jl create mode 100644 src/zygote.jl diff --git a/ext/ADNLPModelsEnzymeExt.jl b/ext/ADNLPModelsEnzymeExt.jl index 13ed9015..60c1deb3 100644 --- a/ext/ADNLPModelsEnzymeExt.jl +++ b/ext/ADNLPModelsEnzymeExt.jl @@ -2,22 +2,9 @@ module ADNLPModelsEnzymeExt using Enzyme, ADNLPModels -struct EnzymeADGradient <: ADNLPModels.ADBackend end - -function EnzymeADGradient( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - x0::AbstractVector = rand(nvar), - kwargs..., -) - return EnzymeADGradient() -end - -function ADNLPModels.gradient!(::EnzymeADGradient, g, f, x) +function ADNLPModels.gradient!(::ADNLPModels.EnzymeADGradient, g, f, x) Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x) return g end -end \ No newline at end of file +end diff --git a/ext/ADNLPModelsZygoteExt.jl b/ext/ADNLPModelsZygoteExt.jl index 9bb8707f..8001ffdf 100644 --- a/ext/ADNLPModelsZygoteExt.jl +++ b/ext/ADNLPModelsZygoteExt.jl @@ -1,47 +1,7 @@ module ADNLPModelsZygoteExt using Zygote, ADNLPModels -import ADNLPModels: ADModel, AbstractADNLSModel, ADBackend, ImmutableADbackend -struct ZygoteADGradient <: ADBackend end -struct ZygoteADJacobian <: ImmutableADbackend - nnzj::Int -end -struct ZygoteADHessian <: ImmutableADbackend - nnzh::Int -end -struct ZygoteADJprod <: ImmutableADbackend end -struct ZygoteADJtprod <: ImmutableADbackend end -# See https://fluxml.ai/Zygote.jl/latest/limitations/ -function get_immutable_c(nlp::ADModel) - function c(x; nnln = nlp.meta.nnln) - c = Zygote.Buffer(x, nnln) - nlp.c!(c, x) - return copy(c) - end - return c -end -get_c(nlp::ADModel, ::ImmutableADbackend) = get_immutable_c(nlp) - -function get_immutable_F(nls::AbstractADNLSModel) - function F(x; nequ = nls.nls_meta.nequ) - Fx = Zygote.Buffer(x, nequ) - nls.F!(Fx, x) - return copy(Fx) - end - return F -end -get_F(nls::AbstractADNLSModel, ::ImmutableADbackend) = get_immutable_F(nls) - -function ZygoteADGradient( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - kwargs..., -) - return ZygoteADGradient() -end function gradient(::ZygoteADGradient, f, x) g = Zygote.gradient(f, x)[1] return g === nothing ? zero(x) : g @@ -51,63 +11,11 @@ function gradient!(::ZygoteADGradient, g, f, x) g .= _g === nothing ? 0 : _g end -function ZygoteADJacobian( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - kwargs..., -) - @assert nvar > 0 - nnzj = nvar * ncon - return ZygoteADJacobian(nnzj) -end -function jacobian(::ZygoteADJacobian, f, x) - return Zygote.jacobian(f, x)[1] -end - -function ZygoteADHessian( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - kwargs..., -) - @assert nvar > 0 - nnzh = nvar * (nvar + 1) / 2 - return ZygoteADHessian(nnzh) -end -function hessian(b::ZygoteADHessian, f, x) - return jacobian( - ForwardDiffADJacobian(length(x), f, x0 = x), - x -> gradient(ZygoteADGradient(), f, x), - x, - ) -end - -function ZygoteADJprod( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - kwargs..., -) - return ZygoteADJprod() -end function Jprod!(::ZygoteADJprod, Jv, f, x, v, ::Val) Jv .= vec(Zygote.jacobian(t -> f(x + t * v), 0)[1]) return Jv end -function ZygoteADJtprod( - nvar::Integer, - f, - ncon::Integer = 0, - c::Function = (args...) -> []; - kwargs..., -) - return ZygoteADJtprod() -end function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val) g = Zygote.gradient(x -> dot(f(x), v), x)[1] if g === nothing @@ -118,4 +26,16 @@ function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val) return Jtv end -end \ No newline at end of file +function jacobian(::ZygoteADJacobian, f, x) + return Zygote.jacobian(f, x)[1] +end + +function hessian(b::ZygoteADHessian, f, x) + return jacobian( + ForwardDiffADJacobian(length(x), f, x0 = x), + x -> gradient(ZygoteADGradient(), f, x), + x, + ) +end + +end diff --git a/src/ADNLPModels.jl b/src/ADNLPModels.jl index a50d1005..0b6e26ad 100644 --- a/src/ADNLPModels.jl +++ b/src/ADNLPModels.jl @@ -27,11 +27,13 @@ include("sparse_hessian.jl") include("forward.jl") include("reverse.jl") -include("enzyme.jl") -include("zygote.jl") include("predefined_backend.jl") include("nlp.jl") +# Extensions +include("enzyme.jl") +include("zygote.jl") + function ADNLPModel!(model::AbstractNLPModel; kwargs...) return if model.meta.nlin > 0 ADNLPModel!( diff --git a/src/enzyme.jl b/src/enzyme.jl new file mode 100644 index 00000000..b57b1cae --- /dev/null +++ b/src/enzyme.jl @@ -0,0 +1,12 @@ +struct EnzymeADGradient <: ADNLPModels.ADBackend end + +function EnzymeADGradient( + nvar::Integer, + f, + ncon::Integer = 0, + c::Function = (args...) -> []; + x0::AbstractVector = rand(nvar), + kwargs..., +) + return EnzymeADGradient() +end diff --git a/src/zygote.jl b/src/zygote.jl new file mode 100644 index 00000000..5c1d8fb0 --- /dev/null +++ b/src/zygote.jl @@ -0,0 +1,83 @@ +struct ZygoteADGradient <: ADBackend end +struct ZygoteADJprod <: ImmutableADbackend end +struct ZygoteADJtprod <: ImmutableADbackend end +struct ZygoteADJacobian <: ImmutableADbackend + nnzj::Int +end +struct ZygoteADHessian <: ImmutableADbackend + nnzh::Int +end + +# See https://fluxml.ai/Zygote.jl/latest/limitations/ +function get_immutable_c(nlp::ADModel) + function c(x; nnln = nlp.meta.nnln) + c = Zygote.Buffer(x, nnln) + nlp.c!(c, x) + return copy(c) + end + return c +end +get_c(nlp::ADModel, ::ImmutableADbackend) = get_immutable_c(nlp) + +function get_immutable_F(nls::AbstractADNLSModel) + function F(x; nequ = nls.nls_meta.nequ) + Fx = Zygote.Buffer(x, nequ) + nls.F!(Fx, x) + return copy(Fx) + end + return F +end +get_F(nls::AbstractADNLSModel, ::ImmutableADbackend) = get_immutable_F(nls) + +function ZygoteADGradient( + nvar::Integer, + f, + ncon::Integer = 0, + c::Function = (args...) -> []; + kwargs..., +) + return ZygoteADGradient() +end + +function ZygoteADJprod( + nvar::Integer, + ncon::Integer = 0, + c::Function = (args...) -> []; + kwargs..., +) + return ZygoteADJprod() +end + +function ZygoteADJtprod( + nvar::Integer, + f, + ncon::Integer = 0, + c::Function = (args...) -> []; + kwargs..., +) + return ZygoteADJtprod() +end + +function ZygoteADJacobian( + nvar::Integer, + f, + ncon::Integer = 0, + c::Function = (args...) -> []; + kwargs..., +) + @assert nvar > 0 + nnzj = nvar * ncon + return ZygoteADJacobian(nnzj) +end + +function ZygoteADHessian( + nvar::Integer, + f, + ncon::Integer = 0, + c::Function = (args...) -> []; + kwargs..., +) + @assert nvar > 0 + nnzh = nvar * (nvar + 1) / 2 + return ZygoteADHessian(nnzh) +end From b97ff8dca39414c78bd09fe24ea39641f745cebf Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Wed, 27 Nov 2024 11:53:49 -0600 Subject: [PATCH 5/5] Update the extensions --- ext/ADNLPModelsZygoteExt.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ext/ADNLPModelsZygoteExt.jl b/ext/ADNLPModelsZygoteExt.jl index 8001ffdf..86f5101e 100644 --- a/ext/ADNLPModelsZygoteExt.jl +++ b/ext/ADNLPModelsZygoteExt.jl @@ -2,21 +2,21 @@ module ADNLPModelsZygoteExt using Zygote, ADNLPModels -function gradient(::ZygoteADGradient, f, x) +function gradient(::ADNLPModels.ZygoteADGradient, f, x) g = Zygote.gradient(f, x)[1] return g === nothing ? zero(x) : g end -function gradient!(::ZygoteADGradient, g, f, x) +function gradient!(::ADNLPModels.ZygoteADGradient, g, f, x) _g = Zygote.gradient(f, x)[1] g .= _g === nothing ? 0 : _g end -function Jprod!(::ZygoteADJprod, Jv, f, x, v, ::Val) +function Jprod!(::ADNLPModels.ZygoteADJprod, Jv, f, x, v, ::Val) Jv .= vec(Zygote.jacobian(t -> f(x + t * v), 0)[1]) return Jv end -function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val) +function Jtprod!(::ADNLPModels.ZygoteADJtprod, Jtv, f, x, v, ::Val) g = Zygote.gradient(x -> dot(f(x), v), x)[1] if g === nothing Jtv .= zero(x) @@ -26,14 +26,14 @@ function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val) return Jtv end -function jacobian(::ZygoteADJacobian, f, x) +function jacobian(::ADNLPModels.ZygoteADJacobian, f, x) return Zygote.jacobian(f, x)[1] end -function hessian(b::ZygoteADHessian, f, x) +function hessian(b::ADNLPModels.ZygoteADHessian, f, x) return jacobian( - ForwardDiffADJacobian(length(x), f, x0 = x), - x -> gradient(ZygoteADGradient(), f, x), + ADNLPModels.ForwardDiffADJacobian(length(x), f, x0 = x), + x -> gradient(ADNLPModels.ZygoteADGradient(), f, x), x, ) end