Skip to content

Commit

Permalink
Merge branch 'main' into thunkua
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Dec 15, 2024
2 parents fbc05f7 + ee21090 commit 9990193
Show file tree
Hide file tree
Showing 14 changed files with 357 additions and 271 deletions.
24 changes: 12 additions & 12 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
FTy = Core.Typeof(f.val)

rt = if A isa UnionAll
Compiler.primal_return_type(mode, FTy, tt)
Compiler.primal_return_type(Reverse, FTy, tt)
else
eltype(A)
end
Expand Down Expand Up @@ -410,7 +410,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
end

opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Reverse, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -536,7 +536,7 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value.
) where {FA<:Annotation,CMode<:Mode,Nargs}
tt = vaEltypeof(args...)
rt = Compiler.primal_return_type(
mode,
mode isa ForwardMode ? Forward : Reverse,
eltype(FA),
tt,
)
Expand Down Expand Up @@ -632,7 +632,7 @@ f(x) = x*x
tt = vaEltypeof(args...)

opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Forward, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -687,7 +687,7 @@ code, as well as high-order differentiation.
A2 = A

if A isa UnionAll
rt = Compiler.primal_return_type(mode, FTy, tt)
rt = Compiler.primal_return_type(Reverse, FTy, tt)
A2 = A{rt}
if rt == Union{}
rt = Nothing
Expand Down Expand Up @@ -840,7 +840,7 @@ code, as well as high-order differentiation.
FT = Core.Typeof(f.val)

if RT isa UnionAll
rt = Compiler.primal_return_type(mode, FT, tt)
rt = Compiler.primal_return_type(Forward, FT, tt)
if rt == Union{}
rt = Nothing
end
Expand Down Expand Up @@ -968,7 +968,7 @@ result, ∂v, ∂A

tt′ = Tuple{args...}
opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Reverse, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -1098,7 +1098,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float

tt′ = Tuple{args...}
opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Forward, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -1166,7 +1166,7 @@ end

primal_tt = Tuple{map(eltype, args)...}
opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), TT)
my_methodinstance(Forward, eltype(FA), primal_tt)
else
Val(0)
end
Expand Down Expand Up @@ -1196,7 +1196,7 @@ const tape_cache = Dict{UInt,Type}()

const tape_cache_lock = ReentrantLock()

import .Compiler: fspec, remove_innerty, UnknownTapeType
import .Compiler: remove_innerty, UnknownTapeType

@inline function tape_type(
parent_job::Union{GPUCompiler.CompilerJob,Nothing},
Expand Down Expand Up @@ -1246,7 +1246,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType

primal_tt = Tuple{map(eltype, args)...}

mi = Compiler.fspec(eltype(FA), TT)
mi = my_methodinstance(parent_job === nothing ? Reverse : GPUCompiler.get_interpreter(parent_job), eltype(FA), primal_tt)

target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams(
Expand Down Expand Up @@ -1383,7 +1383,7 @@ result, ∂v, ∂A

rt = if A2 isa UnionAll
primal_tt = Tuple{map(eltype, args)...}
rt0 = Compiler.primal_return_type(mode, eltype(FA), primal_tt)
rt0 = Compiler.primal_return_type(Reverse, eltype(FA), primal_tt)
A2{rt0}
else
A2
Expand Down
1 change: 1 addition & 0 deletions src/analyses/activity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ end
EnzymeCore.EnzymeRules.inactive_type(T)
else
inmi = my_methodinstance(
nothing,
typeof(EnzymeCore.EnzymeRules.inactive_type),
Tuple{Type{T}},
world,
Expand Down
Loading

0 comments on commit 9990193

Please sign in to comment.