Forward over reverse for variadic function #1336
Comments
For the latter you can't use Duplicated of a float; it won't do what you expect (which is what's happening in your case). You have to make those Active. |
I can't, because I need to apply forward over the reverse pass. Or at least I got an error indicating that. Do you know whether there is a way to compute second-order derivatives of variadic functions with active number arguments using Enzyme? Or is that unclear? I read the type-unstable part in the documentation, but it did not error this time with the proper warning. As far as I understand, not every type-unstable code path fails. And I don't pass in temporary storage. |
Ok. I think I see. I have to move the tuples out. Thanks. |
Applying forward over reverse shouldn't be the issue. For any reverse call (independent of how it is called: directly, in forward over reverse, etc.), Active is required for non-mutable state and Duplicated for mutable state. So if you got an error, it may help to see what it said. |
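To make the Active-vs-Duplicated convention concrete, here is a hand-written (plain Python, not Enzyme) sketch of what a reverse pass over a mutating function looks like under that convention: the mutable output buffer gets a shadow that the reverse pass consumes and zeroes (the Duplicated pattern), while derivatives for the immutable scalar arguments are returned (the Active pattern). The function names are hypothetical; this only illustrates the calling convention.

```python
def f_bang(y, x1, x2):
    # mutating primal: writes the Rosenbrock value into y[0]
    y[0] = (1 - x1)**2 + 100 * (x2 - x1**2)**2

def f_bang_reverse(ry, x1, x2):
    # hand-written reverse pass mirroring Enzyme's convention:
    # the mutable output's shadow ry (Duplicated) is consumed and zeroed,
    # while derivatives for the immutable scalars (Active) are returned
    seed = ry[0]
    ry[0] = 0.0
    dx1 = seed * (-2 * (1 - x1) - 400 * x1 * (x2 - x1**2))
    dx2 = seed * (200 * (x2 - x1**2))
    return (dx1, dx2)

y, ry = [0.0], [1.0]
f_bang(y, 1.0, 2.0)
g = f_bang_reverse(ry, 1.0, 2.0)
# at x = (1.0, 2.0): g == (-400.0, 200.0) and ry[0] == 0.0
```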
Okay. Finally got that second-order. Sorry for removing previous posts.

```julia
using Enzyme

# Rosenbrock
@inline function f(x...)
    (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
end

@inline function f!(y, x...)
    y[1] = f(x...)
end

x = (1.0, 2.0)
y = zeros(1)
f!(y, x...)
y[1] = 0.0
ry = ones(1)
g = zeros(2)
rx = ntuple(2) do i
    Active(x[i])
end

# reverse pass: gradient of f! with respect to the variadic active arguments
function gradient!(g, y, ry, rx...)
    g .= autodiff_deferred(ReverseWithPrimal, f!, Const, Duplicated(y, ry), rx...)[1][2:end]
    return nothing
end

gradient!(g, y, ry, rx...)

# forward over reverse
y[1] = 0.0
dy = ones(1)
ry[1] = 1.0
dry = zeros(1)
drx = ntuple(2) do i
    Active(one(Float64))
end
tdrx = ntuple(2) do i
    Duplicated(rx[i], drx[i])
end
fill!(g, 0.0)
dg = zeros(2)
autodiff(Forward, gradient!, Const, Duplicated(g, dg), Duplicated(y, dy), Duplicated(ry, dry), tdrx...)

# H * drx
h = dg
``` |
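As an independent sanity check (plain Python, not Enzyme), the gradient and Hessian of the Rosenbrock function above can be written down analytically. At x = (1.0, 2.0) the reverse pass should produce g = (-400, 200), and forward-over-reverse with seed direction (1, 1) should produce the Hessian-vector product H·(1, 1) = (2, -200); the finite-difference helper below mimics that directional derivative of the gradient.

```python
def rosen_grad(x1, x2):
    # analytic gradient of f = (1 - x1)^2 + 100 (x2 - x1^2)^2
    return (-2 * (1 - x1) - 400 * x1 * (x2 - x1**2),
            200 * (x2 - x1**2))

def rosen_hess(x1, x2):
    # analytic Hessian of the same f
    return [[2 - 400 * (x2 - x1**2) + 800 * x1**2, -400 * x1],
            [-400 * x1, 200.0]]

def hvp_fd(x1, x2, v1, v2, eps=1e-6):
    # forward-over-reverse in spirit: the directional derivative of the
    # gradient along (v1, v2), approximated by central finite differences
    gp = rosen_grad(x1 + eps * v1, x2 + eps * v2)
    gm = rosen_grad(x1 - eps * v1, x2 - eps * v2)
    return tuple((a - b) / (2 * eps) for a, b in zip(gp, gm))

g = rosen_grad(1.0, 2.0)         # (-400.0, 200.0)
H = rosen_hess(1.0, 2.0)         # [[402.0, -400.0], [-400.0, 200.0]]
hv = hvp_fd(1.0, 2.0, 1.0, 1.0)  # ≈ H @ (1, 1) = (2.0, -200.0)
```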
Here's what I came up with:

```julia
julia> import Enzyme

julia> f(x...) = log(sum(exp.(x)))
f (generic function with 1 method)

julia> function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
           g .= Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active.(x)...)[1]
           return
       end
∇f! (generic function with 1 method)

julia> function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
           direction(i) = ntuple(j -> Enzyme.Active(T(i == j)), N)
           hess = Enzyme.autodiff(
               Enzyme.Forward,
               (x...) -> Enzyme.autodiff_deferred(Enzyme.Reverse, f, x...)[1],
               Enzyme.BatchDuplicated.(Enzyme.Active.(x), ntuple(direction, N))...,
           )[1]
           for j in 1:N, i in 1:j
               H[j, i] = hess[j][i]
           end
           return
       end
∇²f! (generic function with 1 method)

julia> N = 3
3

julia> x, g, H = rand(N), fill(NaN, N), fill(NaN, N, N);

julia> f(x...)
1.7419820145927152

julia> ∇f!(g, x...)

julia> ∇²f!(H, x...)

julia> g
3-element Vector{Float64}:
 0.24224320758303503
 0.30782611265915005
 0.4499306797578149

julia> H
3×3 Matrix{Float64}:
  0.183561   NaN          NaN
 -0.0745688    0.213069   NaN
 -0.108993    -0.1385       0.247493
``` |
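The printed output can be cross-checked by hand: for f(x) = log Σᵢ exp(xᵢ), the gradient is the softmax p (so g sums to 1) and the Hessian is diag(p) − p pᵀ. A small Python check (independent of Enzyme) confirms the lower-triangular entries above against the printed gradient:

```python
import math

def softmax_and_hessian(x):
    # gradient of log-sum-exp is the softmax p; Hessian is diag(p) - p p^T
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    p = [v / s for v in e]
    n = len(x)
    H = [[(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(n)]
         for i in range(n)]
    return p, H

# the gradient printed in the session above should be a softmax vector
g = [0.24224320758303503, 0.30782611265915005, 0.4499306797578149]
assert abs(sum(g) - 1.0) < 1e-10            # softmax sums to 1
assert abs(g[0] * (1 - g[0]) - 0.183561) < 1e-5   # matches H[1, 1]
assert abs(-g[0] * g[1] - (-0.0745688)) < 1e-6    # matches H[2, 1]
```

Each row of diag(p) − p pᵀ sums to zero (pᵢ − pᵢ·Σⱼpⱼ = 0), which is another quick consistency check on any computed log-sum-exp Hessian.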
I'm trying to do forward over reverse over this function.
To achieve this, I try to write a wrapper, since I can't use Active for the forward pass. An example is this:
This crashes with
The other version is the one I sent on Slack. That one doesn't crash, but the adjoints are not updated, although the adjoint of the output is zeroed. |