From 782b3b9b2174e6f256b8cb078f817f190684aa75 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Wed, 13 Mar 2024 15:17:49 -0500 Subject: [PATCH] Enzyme extension --- Project.toml | 2 ++ ext/JuMPEnzymeExt.jl | 76 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 ext/JuMPEnzymeExt.jl diff --git a/Project.toml b/Project.toml index f71308cf950..d00dd88efa9 100644 --- a/Project.toml +++ b/Project.toml @@ -15,9 +15,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [extensions] JuMPDimensionalDataExt = "DimensionalData" +JuMPEnzymeExt = "Enzyme" [compat] DimensionalData = "0.24, 0.25, 0.26.2" diff --git a/ext/JuMPEnzymeExt.jl b/ext/JuMPEnzymeExt.jl new file mode 100644 index 00000000000..49b1e31f6b5 --- /dev/null +++ b/ext/JuMPEnzymeExt.jl @@ -0,0 +1,76 @@ +module JuMPEnzymeExt + +using Enzyme +using JuMP + +function jump_operator(f::Function) + @inline function f!(y, x...) + y[1] = f(x...) + end + function gradient!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N} + y = zeros(T,1) + ry = ones(T,1) + rx = ntuple(N) do i + Active(x[i]) + end + g .= autodiff(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end] + return nothing + end + + function gradient_deferred!(g, y, ry, rx...) + g .= autodiff_deferred(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end] + return nothing + end + + function hessian!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N} + y = zeros(T,1) + dy = ntuple(N) do i + ones(1) + end + g = zeros(T,N) + dg = ntuple(N) do i + zeros(T,N) + end + ry = ones(1) + dry = ntuple(N) do i + zeros(T,1) + end + rx = ntuple(N) do i + Active(x[i]) + end + + args = ntuple(N) do i + drx = ntuple(N) do j + if i == j + Active(one(T)) + else + Active(zero(T)) + end + end + BatchDuplicated(rx[i], drx) + end + autodiff(Forward, gradient_deferred!, Const, BatchDuplicated(g,dg), BatchDuplicated(y,dy), BatchDuplicated(ry, dry), args...) + for i in 1:N + for j in 1:N + if i <= j + H[j,i] = dg[j][i] + end + end + end + return nothing + end + + return gradient!, hessian! +end + +function JuMP.add_nonlinear_operator( + model::GenericModel, + dim::Int, + f::Function; + name::Symbol = Symbol(f), +) + gradient, hessian = jump_operator(f) + MOI.set(model, MOI.UserDefinedFunction(name, dim), tuple(f, gradient, hessian)) + return NonlinearOperator(f, name) +end +end