From da0e36010c3ad8ea52c0386338079a47279c1456 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Fri, 8 Mar 2024 12:16:01 -0600 Subject: [PATCH 1/5] Add JuMP extension --- Project.toml | 3 ++ ext/EnzymeJuMPExt.jl | 80 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 ext/EnzymeJuMPExt.jl diff --git a/Project.toml b/Project.toml index 8f295d1e42..ec91fb0a26 100644 --- a/Project.toml +++ b/Project.toml @@ -17,9 +17,11 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] +JuMP = "4076af6c-e467-56ae-b986-b466b2749572" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] +EnzymeJuMPExt = "JuMP" EnzymeSpecialFunctionsExt = "SpecialFunctions" [compat] @@ -27,6 +29,7 @@ CEnum = "0.4, 0.5" EnzymeCore = "0.7" Enzyme_jll = "0.0.102" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25" +JuMP = "1" LLVM = "6.1" ObjectFile = "0.4" Preferences = "1.4" diff --git a/ext/EnzymeJuMPExt.jl b/ext/EnzymeJuMPExt.jl new file mode 100644 index 0000000000..a7077c047d --- /dev/null +++ b/ext/EnzymeJuMPExt.jl @@ -0,0 +1,80 @@ +module EnzymeJuMPExt + +using Enzyme +using JuMP + +function jump_operator(f) + @inline function f!(y, x...) + y[1] = f(x...) + end + function gradient!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N} + y = zeros(1) + ry = ones(1) + rx = ntuple(N) do i + Active(x[i]) + end + g .= autodiff(ReverseWithPrimal, f!, Const, Duplicated(y,ry), rx...)[1][2:end] + return nothing + end + + function gradient_deferred!(g, y, ry, rx...) + g .= autodiff_deferred(ReverseWithPrimal, f!, Const, Duplicated(y,ry), rx...)[1][2:end] + return nothing + end + + function hessian!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N} + y = zeros(1) + dg = zeros(N) + y[1] = 0.0 + dy = ones(1) + ry = ones(1) + dry = zeros(1) + g = zeros(N) + dg = zeros(N) + + rx = ntuple(N) do i + Active(x[i]) + end + + for j in 1:N + y[1] = 0.0 + dy[1] = 1.0 + ry[1] = 1.0 + dry[1] = 0.0 + drx = ntuple(N) do i + if i == j + Active(one(T)) + else + Active(zero(T)) + end + end + tdrx= ntuple(N) do i + Duplicated(rx[i], drx[i]) + end + fill!(dg, 0.0) + fill!(g, 0.0) + autodiff(Forward, gradient_deferred!, Const, Duplicated(g,dg), Duplicated(y,dy), Duplicated(ry, dry), tdrx...) + for i in 1:N + if i <= j + H[j,i] = dg[i] + end + end + end + + return nothing + end + return gradient!, hessian! +end + +function JuMP.add_nonlinear_operator( + model::GenericModel, + dim::Int, + f::Function; + name::Symbol = Symbol(f), +) + gradient, hessian = jump_operator(f) + @show tuple(f, gradient, hessian) + MOI.set(model, MOI.UserDefinedFunction(name, dim), tuple(f, gradient, hessian)) + return NonlinearOperator(f, name) +end +end From 11ea4e51ce159bb07e753cf302fb47be581a6599 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Fri, 8 Mar 2024 12:23:23 -0600 Subject: [PATCH 2/5] Fix --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index ec91fb0a26..275c47152e 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,6 @@ CEnum = "0.4, 0.5" EnzymeCore = "0.7" Enzyme_jll = "0.0.102" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25" -JuMP = "1" LLVM = "6.1" ObjectFile = "0.4" Preferences = "1.4" From 3f6881d4aa846f8fd4e2aa82f393091e51611477 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Fri, 8 Mar 2024 13:39:51 -0600 Subject: [PATCH 3/5] Batched mode for JuMP --- ext/EnzymeJuMPExt.jl | 40 ++++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/ext/EnzymeJuMPExt.jl b/ext/EnzymeJuMPExt.jl index a7077c047d..32e478c9e5 100644 --- a/ext/EnzymeJuMPExt.jl +++ b/ext/EnzymeJuMPExt.jl @@ -24,39 +24,36 @@ function jump_operator(f) function hessian!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N} y = zeros(1) - dg = zeros(N) - y[1] = 0.0 - dy = ones(1) - ry = ones(1) - dry = zeros(1) + dy = ntuple(N) do i + ones(1) + end g = zeros(N) - dg = zeros(N) - + dg = ntuple(N) do i + zeros(N) + end + ry = ones(1) + dry = ntuple(N) do i + zeros(1) + end rx = ntuple(N) do i Active(x[i]) end - for j in 1:N - y[1] = 0.0 - dy[1] = 1.0 - ry[1] = 1.0 - dry[1] = 0.0 - drx = ntuple(N) do i + args = ntuple(N) do i + drx = ntuple(N) do j if i == j Active(one(T)) else Active(zero(T)) end end - tdrx= ntuple(N) do i - Duplicated(rx[i], drx[i]) - end - fill!(dg, 0.0) - fill!(g, 0.0) - autodiff(Forward, gradient_deferred!, Const, Duplicated(g,dg), Duplicated(y,dy), Duplicated(ry, dry), tdrx...) - for i in 1:N + BatchDuplicated(rx[i], drx) + end + autodiff(Forward, gradient_deferred!, Const, BatchDuplicated(g,dg), BatchDuplicated(y,dy), BatchDuplicated(ry, dry), args...) + for i in 1:N + for j in 1:N if i <= j - H[j,i] = dg[i] + H[j,i] = dg[j][i] end end end @@ -73,7 +70,6 @@ function JuMP.add_nonlinear_operator( name::Symbol = Symbol(f), ) gradient, hessian = jump_operator(f) - @show tuple(f, gradient, hessian) MOI.set(model, MOI.UserDefinedFunction(name, dim), tuple(f, gradient, hessian)) return NonlinearOperator(f, name) end From e4dba05d5392beb35957ff207e02381da38b4b69 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Fri, 8 Mar 2024 13:44:44 -0600 Subject: [PATCH 4/5] No primal needed --- ext/EnzymeJuMPExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/EnzymeJuMPExt.jl b/ext/EnzymeJuMPExt.jl index 32e478c9e5..a8e7e676ab 100644 --- a/ext/EnzymeJuMPExt.jl +++ b/ext/EnzymeJuMPExt.jl @@ -13,12 +13,12 @@ function jump_operator(f) rx = ntuple(N) do i Active(x[i]) end - g .= autodiff(ReverseWithPrimal, f!, Const, Duplicated(y,ry), rx...)[1][2:end] + g .= autodiff(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end] return nothing end function gradient_deferred!(g, y, ry, rx...) - g .= autodiff_deferred(ReverseWithPrimal, f!, Const, Duplicated(y,ry), rx...)[1][2:end] + g .= autodiff_deferred(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end] return nothing end From 6669d87a9d40bb9e9cccb09c5ad1b3cfba7c86d1 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Fri, 8 Mar 2024 13:58:38 -0600 Subject: [PATCH 5/5] Add missing types args --- ext/EnzymeJuMPExt.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ext/EnzymeJuMPExt.jl b/ext/EnzymeJuMPExt.jl index a8e7e676ab..d0bf5ac193 100644 --- a/ext/EnzymeJuMPExt.jl +++ b/ext/EnzymeJuMPExt.jl @@ -3,13 +3,13 @@ module EnzymeJuMPExt using Enzyme using JuMP -function jump_operator(f) +function jump_operator(f::Function) @inline function f!(y, x...) y[1] = f(x...) end function gradient!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N} - y = zeros(1) - ry = ones(1) + y = zeros(T,1) + ry = ones(T,1) rx = ntuple(N) do i Active(x[i]) end @@ -23,17 +23,17 @@ function jump_operator(f) end function hessian!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N} - y = zeros(1) + y = zeros(T,1) dy = ntuple(N) do i ones(1) end - g = zeros(N) + g = zeros(T,N) dg = ntuple(N) do i - zeros(N) + zeros(T,N) end ry = ones(1) dry = ntuple(N) do i - zeros(1) + zeros(T,1) end rx = ntuple(N) do i Active(x[i]) @@ -57,9 +57,9 @@ function jump_operator(f) end end end - return nothing end + return gradient!, hessian! end