Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tape type mechanisms with parent_job capability #734

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ function autodiff_deferred end
function autodiff_thunk end
function autodiff_deferred_thunk end

# Generic function stub: declared here without methods so that downstream
# packages (e.g. Enzyme, which imports and re-exports it) can attach them.
function tape_type end

include("rules.jl")

end # module EnzymeCore
36 changes: 34 additions & 2 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplic
import EnzymeCore: batch_size
export batch_size

import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk
export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk
import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type
export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type

export jacobian, gradient, gradient!
export markType, batch_size, onehot, chunkedonehot
Expand Down Expand Up @@ -520,6 +520,38 @@ result, ∂v, ∂A
Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false))
end

"""
    tape_type(::ReverseModeSplit, ftype, Activity, argtypes...; parent_job=nothing)

Return the tape type of the split reverse-mode thunk compiled for the annotated
function type `ftype`, return activity `Activity` and annotated argument types
`argtypes`.

# Keywords
- `parent_job`: when `nothing` (the default) the regular `Compiler.thunk` is used and
  the tape type is read off the augmented forward thunk. Otherwise the job is
  forwarded to `Compiler.innerthunk`, and the tape type is read from the compiled
  thunk it returns.

# Throws
- `ErrorException` if the mode's width is `0` and the batch size inferred from
  `argtypes` is also `0`.
- `AssertionError` if the mode does not return the shadow.
"""
@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT}, ::Type{FA}, ::Type{A}, args...; parent_job=nothing) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT}
    # A width of 0 means "infer the batch size from the argument annotations".
    width = if Width == 0
        w = same_or_one(args...)
        if w == 0
            throw(ErrorException("Cannot differentiate with a batch size of 0"))
        end
        w
    else
        Width
    end

    # `true` is a placeholder meaning "derive the per-argument tuple from the arguments".
    ModifiedBetween = if ModifiedBetweenT === true
        falses_from_args(Val(1), args...)
    else
        ModifiedBetweenT
    end

    tt = Tuple{map(eltype, args)...}
    world = GPUCompiler.codegen_world_age(eltype(FA), tt)

    @assert ReturnShadow
    if parent_job === nothing
        # No parent job: go through the generated `thunk` and read the tape type
        # off the augmented forward thunk.
        forward, _ = Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), Val(ModifiedBetween), #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false))
        return Compiler.get_tape_type(forward)
    else
        # `innerthunk` takes `parent_job` as a trailing positional argument, not a keyword.
        compiled, _ = Enzyme.Compiler.innerthunk(world, FA, A, Tuple{args...}, #=Split=#API.DEM_ReverseModeGradient, width, ModifiedBetween, #=ReturnPrimal=#ReturnPrimal, #=ShadowInit=#false, parent_job)
        return compiled.TapeType
    end
end

"""
autodiff_thunk(::ForwardMode, ftype, Activity, argtypes...)

Expand Down
69 changes: 46 additions & 23 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,8 @@ end
# Accessors that read a type parameter out of a thunk *type*.
@inline function return_type(::Type{AugmentedForwardThunk{FA, RT, TT, Width, ReturnPrimal, TapeType}}) where {FA, RT, TT, Width, ReturnPrimal, TapeType}
    return RT
end
@inline function get_tape_type(::Type{AugmentedForwardThunk{FA, RT, TT, Width, ReturnPrimal, TapeType}}) where {FA, RT, TT, Width, ReturnPrimal, TapeType}
    return TapeType
end
@inline function get_tape_type(::Type{AdjointThunk{FA, RT, TT, Width, TapeType}}) where {FA, RT, TT, Width, TapeType}
    return TapeType
end

# The same tape-type query, accepting a thunk *instance* instead of its type.
@inline function get_tape_type(::AugmentedForwardThunk{FA, RT, TT, Width, ReturnPrimal, TapeType}) where {FA, RT, TT, Width, ReturnPrimal, TapeType}
    return TapeType
end
@inline function get_tape_type(::AdjointThunk{FA, RT, TT, Width, TapeType}) where {FA, RT, TT, Width, TapeType}
    return TapeType
end

using .JIT

Expand Down Expand Up @@ -2716,7 +2718,9 @@ end
ctx = LLVM.context(orig)

llvmfn = LLVM.called_value(orig)
mi = nothing

mi, job = enzyme_custom_extract_mi(orig)

fwdmodenm = nothing
augfwdnm = nothing
adjointnm = nothing
Expand Down Expand Up @@ -2774,8 +2778,8 @@ end
subfunc = nothing
if mode == API.DEM_ForwardMode
if fwdmodenm === nothing
etarget = Compiler.EnzymeTarget()
eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType)
etarget = Compiler.EnzymeTarget(job.config.target.parent_target)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
etarget = Compiler.EnzymeTarget(job.config.target.parent_target)
etarget = job.config.target

eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, GPUCompiler.method_table(job))
ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


jctx = ctx
Expand Down Expand Up @@ -2827,9 +2831,8 @@ end
end

if augfwdnm === nothing || adjointnm === nothing
etarget = Compiler.EnzymeTarget()
# TODO modifiedBetween
eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType)
etarget = Compiler.EnzymeTarget(job.config.target.parent_target)
eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, GPUCompiler.method_table(job))
ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world)
jctx = ctx
@static if VERSION < v"1.9-"
Expand Down Expand Up @@ -3158,7 +3161,7 @@ function newtask_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef,
LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, ops[1], B)),
LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, ops[2])),
emit_box_int64!(B, LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, ops[3]))),
unsafe_to_llvm(Val(width), ctx),
unsafe_to_llvm(Val(Int(width)), ctx),
]

ntask = emit_apply_generic!(B, vals)
Expand Down Expand Up @@ -3213,7 +3216,7 @@ function newtask_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRe
LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, ops[1], B)),
LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, ops[2])),
emit_box_int64!(B, LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, ops[3]))),
unsafe_to_llvm(Val(width), ctx),
unsafe_to_llvm(Val(Int(width)), ctx),
unsafe_to_llvm(Val(ModifiedBetween), ctx),
]

Expand Down Expand Up @@ -6197,10 +6200,25 @@ end

# Define EnzymeTarget
Base.@kwdef struct EnzymeTarget <: AbstractCompilerTarget
    # Optional compiler target to defer to. Defaults to `nothing` (no parent) so
    # that the zero-argument constructor `EnzymeTarget()` keeps working alongside
    # `EnzymeTarget(parent)` and `EnzymeTarget(parent_target=parent)`.
    parent_target::Union{Nothing, AbstractCompilerTarget} = nothing
end
# Target triple: the host's (`Sys.MACHINE`) when there is no parent target,
# otherwise whatever the parent target reports. The condition must come first
# in Julia's `cond ? a : b`.
GPUCompiler.llvm_triple(T::EnzymeTarget) = T.parent_target === nothing ? Sys.MACHINE : GPUCompiler.llvm_triple(T.parent_target)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
GPUCompiler.llvm_triple(T::EnzymeTarget) = Sys.MACHINE ? T.parent_target === nothing : GPUCompiler.llvm_triple(T.parent_target)
GPUCompiler.llvm_triple(T::EnzymeTarget) = T.parent_target === nothing ? Sys.MACHINE : GPUCompiler.llvm_triple(T.parent_target)

# Data layout: built from this target's own machine when there is no parent
# target, otherwise delegated to the parent target.
GPUCompiler.llvm_datalayout(T::EnzymeTarget) = T.parent_target === nothing ? LLVM.DataLayout(GPUCompiler.llvm_machine(T)) : GPUCompiler.llvm_datalayout(T.parent_target)
# FMA support: `false` when there is no parent target, otherwise delegated to
# the parent target.
# NOTE(review): confirm the argument list GPUCompiler expects for `have_fma`.
GPUCompiler.have_fma(T::EnzymeTarget) = T.parent_target === nothing ? false : GPUCompiler.have_fma(T.parent_target)

# GPUCompiler.isintrinsic(@nospecialize(job::CompilerJob{EnzymeTarget}), fn::String) = false if job.target.parent_target === nothing else GPUCompiler.isintrinsic(job.target.parent_target, fn)
# GPUCompiler.runtime_slug(@nospecialize(job::CompilerJob{EnzymeTarget})) = "enzyme" * ("" if job.target.parent_target === nothing else GPUCompiler.runtime_slug(job.target.parent_target))

# Module post-processing hook for jobs compiled under an `EnzymeTarget`.
# Currently a no-op: delegation to the parent target is still commented out.
function GPUCompiler.process_module!(@nospecialize(job::CompilerJob{EnzymeTarget}), mod::LLVM.Module)
    # The target is reached through the job's config, consistent with the
    # `job.config.target` / `job.config.params` accesses used elsewhere.
    if job.config.target.parent_target !== nothing
        # TODO(review): a reviewer asked of the delegation below "That should be correct?"
        # process_module!(similar(job, job.config.target.parent_target), mod)
    end
    return nothing
end
GPUCompiler.llvm_triple(::EnzymeTarget) = Sys.MACHINE

# GPUCompiler.llvm_datalayout(::EnzymeTarget) = nothing

# TODO: encode debug build or not in the compiler job
# https://github.com/JuliaGPU/CUDAnative.jl/issues/368
GPUCompiler.runtime_slug(job::CompilerJob{EnzymeTarget}) = "enzyme"

function GPUCompiler.llvm_machine(::EnzymeTarget)
return tm[]
Expand All @@ -6225,7 +6243,9 @@ struct EnzymeCompilerParams <: AbstractEnzymeCompilerParams
# Whether to (in aug fwd) += by one
shadowInit::Bool
expectedTapeType::Type
method_table
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
end
# Method table for an Enzyme job: the one stored in the job's compiler params
# when set, otherwise GPUCompiler's global method table.
function GPUCompiler.method_table(@nospecialize(job::CompilerJob{T,EnzymeCompilerParams})) where T
    table = job.config.params.method_table
    if table === nothing
        return GPUCompiler.GLOBAL_METHOD_TABLE
    end
    return table
end

struct UnknownTapeType end

Expand All @@ -6244,9 +6264,6 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams})
# GPUCompiler.isintrinsic(::CompilerJob{EnzymeTarget}, fn::String) = true
# GPUCompiler.can_throw(::CompilerJob{EnzymeTarget}) = true

# TODO: encode debug build or not in the compiler job
# https://github.com/JuliaGPU/CUDAnative.jl/issues/368
GPUCompiler.runtime_slug(job::CompilerJob{EnzymeTarget}) = "enzyme"

# provide a specific interpreter to use.
GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
Expand Down Expand Up @@ -8711,11 +8728,11 @@ end
@inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated
@inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed

@generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false), ::Val{ShadowInit}=Val(false)) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World}
@inline function innerthunk(World::UInt, FA::Type{<:Annotation}, A::Type{<:Annotation}, TT::Type{<:Tuple}, Mode::API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{N, Bool} where N, ReturnPrimal::Bool, ShadowInit::Bool, parent_job=nothing)
mi = fspec(eltype(FA), TT, World)

target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType)
target = Compiler.EnzymeTarget(parent_job !== nothing ? parent_job.target : nothing)
params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, parent_job !== nothing ? GPUCompiler.method_table(parent_job) : nothing)
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)

sig = Tuple{eltype(FA), map(eltype, TT.parameters)...}
Expand Down Expand Up @@ -8750,25 +8767,31 @@ end
# This is counter-intuitive since we would expect the cache to be split
# by the primal, but we want the generated code to be invalidated by
# invalidations of the primal, which is managed by GPUCompiler.
thunk = cached_compilation(job)::Thunk

return thunk, rt
end

@generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, ::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false), ::Val{ShadowInit}=Val(false)) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World}
parent_job=nothing
thunk, rt = innerthunk(World, FA, A, TT, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit, parent_job)

thunk = cached_compilation(job)::Thunk
if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient
TapeType = thunk.TapeType
AugT = AugmentedForwardThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal), TapeType}
AdjT = AdjointThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, TapeType}
AugT = AugmentedForwardThunk{FA, rt, TT, Val{width}, Val(ReturnPrimal), TapeType}
AdjT = AdjointThunk{FA, rt, TT, Val{width}, TapeType}
return quote
augmented = $AugT($(thunk.primal))
adjoint = $AdjT($(thunk.adjoint))
(augmented, adjoint)
end
elseif Mode == API.DEM_ReverseModeCombined
CAdjT = CombinedAdjointThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)}
CAdjT = CombinedAdjointThunk{FA, rt, TT, Val{width}, Val(ReturnPrimal)}
return quote
$CAdjT($(thunk.adjoint))
end
elseif Mode == API.DEM_ForwardMode
FMT = ForwardModeThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)}
FMT = ForwardModeThunk{FA, rt, TT, Val{width}, Val(ReturnPrimal)}
return quote
$FMT($(thunk.adjoint))
end
Expand All @@ -8782,8 +8805,8 @@ import GPUCompiler: deferred_codegen_jobs
@generated function deferred_codegen(::Val{World}, ::Type{FA}, ::Val{tt}, ::Val{rt},::Val{Mode},
::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false),::Val{ShadowInit}=Val(false),::Type{ExpectedTapeType}=UnknownTapeType) where {World, FA<:Annotation,tt, rt, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType}
mi = fspec(eltype(FA), tt, World)
target = EnzymeTarget()
params = EnzymeCompilerParams(Tuple{FA, tt.parameters...}, Mode, width, remove_innerty(rt), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType)
target = EnzymeTarget(nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this NativeTarget then we don't need to worry about nothing vs defined?

params = EnzymeCompilerParams(Tuple{FA, tt.parameters...}, Mode, width, remove_innerty(rt), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, nothing)
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)

adjoint_addr, primal_addr = get_trampoline(job)
Expand Down
7 changes: 4 additions & 3 deletions src/compiler/pmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ function commonInnerCompile(runtime_fn, B, orig, gutils, tape, mode)
mod = LLVM.parent(LLVM.parent(LLVM.parent(orig)))
ctx = LLVM.context(orig)

mi, job = enzyme_custom_extract_mi(orig)

llvmfn = LLVM.called_value(orig)
mi = nothing
adjointnm = nothing
augfwdnm = nothing
TapeType = nothing
Expand Down Expand Up @@ -159,11 +160,11 @@ function commonInnerCompile(runtime_fn, B, orig, gutils, tape, mode)

if augfwdnm === nothing
# TODO: Clean this up and add to `nested_codegen!` as a feature
etarget = Compiler.EnzymeTarget()
etarget = Compiler.EnzymeTarget(job.config.target.parent_target)
funcOverwritten = true
indexOverwritten = false
eparams = Compiler.EnzymeCompilerParams(Tuple{Const{funcT}, dup...}, API.DEM_ReverseModePrimal, width, Const{RT}, true,
#=abiwrap=#true, #=modifiedBetween=#(funcOverwritten, indexOverwritten, overwritten...,), #=returnPrimal=#false, #=shadowprimalInit=#false, Compiler.UnknownTapeType)
#=abiwrap=#true, #=modifiedBetween=#(funcOverwritten, indexOverwritten, overwritten...,), #=returnPrimal=#false, #=shadowprimalInit=#false, Compiler.UnknownTapeType, GPUCompiler.method_table(job))
ejob = Compiler.CompilerJob(eprimal, CompilerConfig(etarget, eparams; kernel=false), world)

jctx = ctx
Expand Down
6 changes: 3 additions & 3 deletions src/compiler/reflection.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types);
run_enzyme::Bool=true, mode::API.CDerivativeMode=API.DEM_ReverseModeCombined, dupClosure::Bool=false, argwrap::Bool=true, width::Int=1, modifiedBetween=nothing, returnPrimal::Bool=false, augmentedInit=false, world=nothing, kwargs...)
run_enzyme::Bool=true, mode::API.CDerivativeMode=API.DEM_ReverseModeCombined, dupClosure::Bool=false, argwrap::Bool=true, width::Int=1, modifiedBetween=nothing, returnPrimal::Bool=false, augmentedInit=false, world=nothing, parent_target=nothing, method_table=nothing, kwargs...)

tt = Tuple{map(eltype, types.parameters)...}
if world === nothing
Expand All @@ -10,12 +10,12 @@ function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types);

rt = Core.Compiler.return_type(func, tt, world)
rt = A{rt}
target = Compiler.EnzymeTarget()
target = Compiler.EnzymeTarget(parent_target)
if modifiedBetween === nothing
defaultMod = mode != API.DEM_ReverseModeCombined && mode != API.DEM_ForwardMode
modifiedBetween = (defaultMod, (defaultMod for _ in types.parameters)...)
end
params = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){Core.Typeof(func)}, types.parameters...}, mode, width, remove_innerty(rt), run_enzyme, argwrap, modifiedBetween, returnPrimal, augmentedInit, Compiler.UnknownTapeType)
params = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){Core.Typeof(func)}, types.parameters...}, mode, width, remove_innerty(rt), run_enzyme, argwrap, modifiedBetween, returnPrimal, augmentedInit, Compiler.UnknownTapeType, method_table)
return Compiler.CompilerJob(primal, CompilerConfig(target, params; kernel=false), world)
end

Expand Down