Skip to content

Commit

Permalink
Fix make_zero(!) corner case bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
danielwe committed Oct 11, 2024
1 parent ad86689 commit a916d55
Showing 1 changed file with 39 additions and 19 deletions.
58 changes: 39 additions & 19 deletions src/make_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ end
prev::Complex{RT},
::Val{copy_if_inactive} = Val(false),
)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat}
return RT(0)
return Complex{RT}(0)
end

@inline function EnzymeCore.make_zero(
Expand Down Expand Up @@ -117,11 +117,10 @@ end
return seen[prev]
end
prev2 = prev.contents
res = Core.Box()
seen[prev] = res
res.contents = Base.Ref(
res = Core.Box(
EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)),
)
seen[prev] = res
return res
end

Expand Down Expand Up @@ -160,7 +159,9 @@ end
end

if nf == 0
return prev
# Unclear what types might end up here rather than in specialized methods or
# guaranteed_const_nongen, but as a last-ditch attempt try falling back to Base.zero
return Base.zero(prev)::RT
end

flds = Vector{Any}(undef, nf)
Expand All @@ -187,27 +188,41 @@ function make_zero_immutable!(
prev::Complex{T},
seen::S,
)::Complex{T} where {T<:AbstractFloat,S}
zero(T)
zero(Complex{T})
end

function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S}
ntuple(Val(length(T.parameters))) do i
Base.@_inline_meta
make_zero_immutable!(prev[i], seen)
p = prev[i]
SBT = Core.Typeof(p)
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
make_zero_immutable!(p, seen)
else
EnzymeCore.make_zero!(p, seen)
p
end
end
end

function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S}
NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i
NamedTuple{a,b}(ntuple(Val(length(b.parameters))) do i
Base.@_inline_meta
make_zero_immutable!(prev[a[i]], seen)
p = prev[a[i]]
SBT = Core.Typeof(p)
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
make_zero_immutable!(p, seen)
else
EnzymeCore.make_zero!(p, seen)
p
end
end)
end


function make_zero_immutable!(prev::T, seen::S)::T where {T,S}
if guaranteed_const_nongen(T, nothing)
return prev
return prev # Note: unreachable from make_zero!
end
@assert !ismutable(prev)

Expand Down Expand Up @@ -239,15 +254,15 @@ end
prev::Base.RefValue{T},
seen::ST,
)::Nothing where {T<:AbstractFloat,ST}
T[] = zero(T)
prev[] = zero(T)
nothing
end

@inline function EnzymeCore.make_zero!(
prev::Base.RefValue{Complex{T}},
seen::ST,
)::Nothing where {T<:AbstractFloat,ST}
T[] = zero(Complex{T})
prev[] = zero(Complex{T})
nothing
end

Expand Down Expand Up @@ -297,7 +312,7 @@ end
if guaranteed_const_nongen(T, nothing)
return
end
if in(seen, prev)
if prev in seen
return
end
push!(seen, prev)
Expand Down Expand Up @@ -325,7 +340,7 @@ end
if guaranteed_const_nongen(T, nothing)
return
end
if in(seen, prev)
if prev in seen
return
end
push!(seen, prev)
Expand All @@ -348,13 +363,13 @@ end
if guaranteed_const_nongen(T, nothing)
return
end
if in(seen, prev)
if prev in seen
return
end
push!(seen, prev)
SBT = Core.Typeof(pv)
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
prev.contents = EnzymeCore.make_zero_immutable!(pv, seen)
prev.contents = make_zero_immutable!(pv, seen)
nothing
else
EnzymeCore.make_zero!(pv, seen)
Expand All @@ -370,7 +385,7 @@ end
if guaranteed_const_nongen(T, nothing)
return
end
if in(prev, seen)
if prev in seen
return
end
@assert !Base.isabstracttype(T)
Expand All @@ -379,7 +394,7 @@ end


if nf == 0
return
error("cannot zero $T in-place: it is apparently differentiable but has no fields")
end

push!(seen, prev)
Expand All @@ -392,7 +407,12 @@ end
continue
end
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
setfield!(prev, i, make_zero_immutable!(xi, seen))
yi = make_zero_immutable!(xi, seen)
if Base.isconst(T, i)
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), prev, i-1, yi)
else
setfield!(prev, i, yi)
end
nothing
else
EnzymeCore.make_zero!(xi, seen)
Expand Down

0 comments on commit a916d55

Please sign in to comment.