From 22818bf12ff50285e306570d1725b33103d582e3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 29 Nov 2024 11:28:57 -0500 Subject: [PATCH] Fix fwd to not have ref on active rtfix (#2142) * Fix fwd to not have ref on active rtfix * Update runtests.jl --- src/errors.jl | 21 +++++++++++++++++---- test/runtests.jl | 28 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/errors.jl b/src/errors.jl index c6dd78b781..187e42b30c 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -286,6 +286,7 @@ function julia_error( end illegalVal = nothing + mode = get_mode(gutils) function make_replacement(@nospecialize(cur::LLVM.Value), prevbb::LLVM.IRBuilder)::LLVM.Value ncur = new_from_original(gutils, cur) @@ -308,15 +309,27 @@ function julia_error( isa(cur, LLVM.ConstantExpr) && cur == data2 if width == 1 - res = emit_allocobj!(prevbb, Base.RefValue{TT}) - push!(created, res) - return res + if mode == API.DEM_ForwardMode + instance = make_zero(obj) + return unsafe_to_llvm(prevbb, instance) + else + res = emit_allocobj!(prevbb, Base.RefValue{TT}) + push!(created, res) + return res + end else shadowres = UndefValue( LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))), ) for idx = 1:width - res = emit_allocobj!(prevbb, Base.RefValue{TT}) + res = if mode == API.DEM_ForwardMode + instance = make_zero(obj) + unsafe_to_llvm(prevbb, instance) + else + sres = emit_allocobj!(prevbb, Base.RefValue{TT}) + push!(created, sres) + sres + end shadowres = insert_value!(prevbb, shadowres, res, idx - 1) push!(created, shadowres) end diff --git a/test/runtests.jl b/test/runtests.jl index e331645378..549978b894 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1272,6 +1272,34 @@ end @test dweights[2] ≈ 20. end + +abstract type AbsFwdType end + +# Two copies of the same type. +struct FwdNormal1{T<:Real} <: AbsFwdType +σ::T +end + +struct FwdNormal2{T<:Real} <: AbsFwdType +σ::T +end + +fwdlogpdf(d) = d.σ + +function absactfunc(x) + dists = AbsFwdType[FwdNormal1{Float64}(1.0), FwdNormal2{Float64}(x)] + res = Vector{Float64}(undef, 2) + for i in 1:length(dists) + @inbounds res[i] = fwdlogpdf(dists[i]) + end + return @inbounds res[1] + @inbounds res[2] +end + +@testset "Forward Mode active runtime activity" begin + res = Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(absactfunc), Duplicated(2.7, 3.1)) + @test res[1] ≈ 3.1 +end + # dot product (https://github.com/EnzymeAD/Enzyme.jl/issues/495) @testset "Dot product" for T in (Float32, Float64) xx = rand(T, 10)