Skip to content

Commit

Permalink
Fix fwd to not have ref on active rtfix (#2142)
Browse files Browse the repository at this point in the history
* Fix fwd to not have ref on active rtfix

* Update runtests.jl
  • Loading branch information
wsmoses authored Nov 29, 2024
1 parent 2bfc9b5 commit 22818bf
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
21 changes: 17 additions & 4 deletions src/errors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ function julia_error(
end

illegalVal = nothing
mode = get_mode(gutils)

function make_replacement(@nospecialize(cur::LLVM.Value), prevbb::LLVM.IRBuilder)::LLVM.Value
ncur = new_from_original(gutils, cur)
Expand All @@ -308,15 +309,27 @@ function julia_error(
isa(cur, LLVM.ConstantExpr) &&
cur == data2
if width == 1
res = emit_allocobj!(prevbb, Base.RefValue{TT})
push!(created, res)
return res
if mode == API.DEM_ForwardMode
instance = make_zero(obj)
return unsafe_to_llvm(prevbb, instance)
else
res = emit_allocobj!(prevbb, Base.RefValue{TT})
push!(created, res)
return res
end
else
shadowres = UndefValue(
LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))),
)
for idx = 1:width
res = emit_allocobj!(prevbb, Base.RefValue{TT})
res = if mode == API.DEM_ForwardMode
instance = make_zero(obj)
unsafe_to_llvm(prevbb, instance)
else
sres = emit_allocobj!(prevbb, Base.RefValue{TT})
push!(created, sres)
sres
end
shadowres = insert_value!(prevbb, shadowres, res, idx - 1)
push!(created, shadowres)
end
Expand Down
28 changes: 28 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1272,6 +1272,34 @@ end
@test dweights[2] 20.
end


abstract type AbsFwdType end

# Two copies of the same type.
struct FwdNormal1{T<:Real} <: AbsFwdType
σ::T
end

struct FwdNormal2{T<:Real} <: AbsFwdType
σ::T
end

fwdlogpdf(d) = d.σ

function absactfunc(x)
dists = AbsFwdType[FwdNormal1{Float64}(1.0), FwdNormal2{Float64}(x)]
res = Vector{Float64}(undef, 2)
for i in 1:length(dists)
@inbounds res[i] = fwdlogpdf(dists[i])
end
return @inbounds res[1] + @inbounds res[2]
end

@testset "Forward Mode active runtime activity" begin
res = Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(absactfunc), Duplicated(2.7, 3.1))
@test res[1] 3.1
end

# dot product (https://github.com/EnzymeAD/Enzyme.jl/issues/495)
@testset "Dot product" for T in (Float32, Float64)
xx = rand(T, 10)
Expand Down

0 comments on commit 22818bf

Please sign in to comment.