-
Notifications
You must be signed in to change notification settings - Fork 69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add tape type mechanisms with parent_job capability #734
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -309,6 +309,8 @@ end | |||||
# Accessors that read type parameters back out of the thunk types.
# `return_type` is defined on the thunk *type*; `get_tape_type` is defined both
# on the thunk type and on a thunk instance.

# `RT` parameter of an `AugmentedForwardThunk` type.
@inline function return_type(
    ::Type{AugmentedForwardThunk{FA, RT, TT, Width, ReturnPrimal, TapeType}}
) where {FA, RT, TT, Width, ReturnPrimal, TapeType}
    return RT
end

# `TapeType` parameter of an `AugmentedForwardThunk` type.
@inline function get_tape_type(
    ::Type{AugmentedForwardThunk{FA, RT, TT, Width, ReturnPrimal, TapeType}}
) where {FA, RT, TT, Width, ReturnPrimal, TapeType}
    return TapeType
end

# `TapeType` parameter of an `AdjointThunk` type.
@inline function get_tape_type(
    ::Type{AdjointThunk{FA, RT, TT, Width, TapeType}}
) where {FA, RT, TT, Width, TapeType}
    return TapeType
end

# Same lookup, but starting from an `AugmentedForwardThunk` instance.
@inline function get_tape_type(
    ::AugmentedForwardThunk{FA, RT, TT, Width, ReturnPrimal, TapeType}
) where {FA, RT, TT, Width, ReturnPrimal, TapeType}
    return TapeType
end

# Same lookup, but starting from an `AdjointThunk` instance.
@inline function get_tape_type(
    ::AdjointThunk{FA, RT, TT, Width, TapeType}
) where {FA, RT, TT, Width, TapeType}
    return TapeType
end
|
||||||
using .JIT | ||||||
|
||||||
|
@@ -2716,7 +2718,9 @@ end | |||||
ctx = LLVM.context(orig) | ||||||
|
||||||
llvmfn = LLVM.called_value(orig) | ||||||
mi = nothing | ||||||
|
||||||
mi, job = enzyme_custom_extract_mi(orig) | ||||||
|
||||||
fwdmodenm = nothing | ||||||
augfwdnm = nothing | ||||||
adjointnm = nothing | ||||||
|
@@ -2774,8 +2778,8 @@ end | |||||
subfunc = nothing | ||||||
if mode == API.DEM_ForwardMode | ||||||
if fwdmodenm === nothing | ||||||
etarget = Compiler.EnzymeTarget() | ||||||
eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType) | ||||||
etarget = Compiler.EnzymeTarget(job.config.target.parent_target) | ||||||
eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ForwardMode, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, GPUCompiler.method_table(job)) | ||||||
ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Btw, there is a copy constructor: https://github.com/JuliaGPU/GPUCompiler.jl/blob/d5086fb3d93bbc4795a96f6f1457898af46a24cb/src/interface.jl#L111-L115 |
||||||
|
||||||
jctx = ctx | ||||||
|
@@ -2827,9 +2831,8 @@ end | |||||
end | ||||||
|
||||||
if augfwdnm === nothing || adjointnm === nothing | ||||||
etarget = Compiler.EnzymeTarget() | ||||||
# TODO modifiedBetween | ||||||
eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType) | ||||||
etarget = Compiler.EnzymeTarget(job.config.target.parent_target) | ||||||
eparams = Compiler.EnzymeCompilerParams(Tuple{(dupClosure ? Duplicated : Const){funcT}, e_tt.parameters...}, API.DEM_ReverseModePrimal, width, Const{Nothing}, #=runEnzyme=#true, #=abiwrap=#true, modifiedBetween, #=returnPrimal=#false, #=shadowInit=#false, UnknownTapeType, GPUCompiler.method_table(job)) | ||||||
ejob = Compiler.CompilerJob(mi2, CompilerConfig(etarget, eparams; kernel=false), world) | ||||||
jctx = ctx | ||||||
@static if VERSION < v"1.9-" | ||||||
|
@@ -3158,7 +3161,7 @@ function newtask_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, | |||||
LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, ops[1], B)), | ||||||
LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, ops[2])), | ||||||
emit_box_int64!(B, LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, ops[3]))), | ||||||
unsafe_to_llvm(Val(width), ctx), | ||||||
unsafe_to_llvm(Val(Int(width)), ctx), | ||||||
] | ||||||
|
||||||
ntask = emit_apply_generic!(B, vals) | ||||||
|
@@ -3213,7 +3216,7 @@ function newtask_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRe | |||||
LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, ops[1], B)), | ||||||
LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, ops[2])), | ||||||
emit_box_int64!(B, LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, ops[3]))), | ||||||
unsafe_to_llvm(Val(width), ctx), | ||||||
unsafe_to_llvm(Val(Int(width)), ctx), | ||||||
unsafe_to_llvm(Val(ModifiedBetween), ctx), | ||||||
] | ||||||
|
||||||
|
@@ -6197,10 +6200,25 @@ end | |||||
|
||||||
# Define EnzymeTarget | ||||||
# Compiler target used for Enzyme jobs.
#
# `parent_target` optionally holds another compiler target that target queries
# (triple, data layout, ...) can be forwarded to; `nothing` means there is no
# parent target.
#
# The field defaults to `nothing` so that both the zero-argument form
# `EnzymeTarget()` and the explicit forms `EnzymeTarget(nothing)` /
# `EnzymeTarget(parent)` construct a valid target. Without a default,
# `Base.@kwdef` generates a keyword constructor with a required keyword and
# `EnzymeTarget()` throws.
Base.@kwdef struct EnzymeTarget <: AbstractCompilerTarget
    parent_target::Union{Nothing, AbstractCompilerTarget} = nothing
end
# LLVM target triple for an Enzyme target: the host triple (`Sys.MACHINE`) when
# there is no parent target, otherwise the parent target's triple.
#
# Julia's ternary is `condition ? if_true : if_false`. The condition has to come
# first; with the operands swapped, the `String` `Sys.MACHINE` ends up in the
# condition slot and the call fails with a non-boolean `TypeError`.
GPUCompiler.llvm_triple(T::EnzymeTarget) =
    T.parent_target === nothing ? Sys.MACHINE : GPUCompiler.llvm_triple(T.parent_target)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
# LLVM data layout for an Enzyme target: derived from this target's own machine
# (`GPUCompiler.llvm_machine(T)`) when there is no parent target, otherwise the
# parent target's data layout.
#
# The `parent_target === nothing` test must be the ternary's condition; a
# `DataLayout` object in the condition slot is a non-boolean `TypeError`.
GPUCompiler.llvm_datalayout(T::EnzymeTarget) =
    T.parent_target === nothing ?
    LLVM.DataLayout(GPUCompiler.llvm_machine(T)) :
    GPUCompiler.llvm_datalayout(T.parent_target)
# FMA availability for an Enzyme target: `false` when there is no parent target,
# otherwise whatever the parent target reports.
#
# The `parent_target === nothing` test must be the ternary's condition. With the
# literal `false` as the condition the else-branch is always taken, so the
# parent is queried even when it is `nothing`.
GPUCompiler.have_fma(T::EnzymeTarget) =
    T.parent_target === nothing ? false : GPUCompiler.have_fma(T.parent_target)
|
||||||
# GPUCompiler.isintrinsic(@nospecialize(job::CompilerJob{EnzymeTarget}), fn::String) = false if job.target.parent_target === nothing else GPUCompiler.isintrinsic(job.target.parent_target, fn) | ||||||
# GPUCompiler.runtime_slug(@nospecialize(job::CompilerJob{EnzymeTarget})) = "enzyme" * ("" if job.target.parent_target === nothing else GPUCompiler.runtime_slug(job.target.parent_target)) | ||||||
|
||||||
function GPUCompiler.process_module!(@nospecialize(job::CompilerJob{EnzymeTarget}), mod::LLVM.Module) | ||||||
if job.target.parent_target !== nothing | ||||||
# process_module!(similar(job, job.target.parent_target), mod) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That should be correct? |
||||||
end | ||||||
end | ||||||
GPUCompiler.llvm_triple(::EnzymeTarget) = Sys.MACHINE | ||||||
|
||||||
# GPUCompiler.llvm_datalayout(::EnzymeTarget) = nothing | ||||||
|
||||||
# TODO: encode debug build or not in the compiler job | ||||||
# https://github.com/JuliaGPU/CUDAnative.jl/issues/368 | ||||||
# Every Enzyme-target job uses the same fixed slug, independent of the job's
# contents.
function GPUCompiler.runtime_slug(job::CompilerJob{EnzymeTarget})
    return "enzyme"
end
|
||||||
function GPUCompiler.llvm_machine(::EnzymeTarget) | ||||||
return tm[] | ||||||
|
@@ -6225,7 +6243,9 @@ struct EnzymeCompilerParams <: AbstractEnzymeCompilerParams | |||||
# Whether to (in aug fwd) += by one | ||||||
shadowInit::Bool | ||||||
expectedTapeType::Type | ||||||
method_table | ||||||
wsmoses marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
end | ||||||
# Method table for an Enzyme job: the table stored in the job's compiler
# params when one was provided, otherwise GPUCompiler's global default.
function GPUCompiler.method_table(
    @nospecialize(job::CompilerJob{T,EnzymeCompilerParams})
) where {T}
    custom_table = job.config.params.method_table
    if custom_table === nothing
        return GPUCompiler.GLOBAL_METHOD_TABLE
    end
    return custom_table
end
|
||||||
# Field-less marker type. In this file it is passed as the expected-tape-type
# argument of `EnzymeCompilerParams` when no concrete tape type is supplied
# (presumably meaning "tape type not known yet" -- confirm against its uses).
struct UnknownTapeType
end
|
||||||
|
@@ -6244,9 +6264,6 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) | |||||
# GPUCompiler.isintrinsic(::CompilerJob{EnzymeTarget}, fn::String) = true | ||||||
# GPUCompiler.can_throw(::CompilerJob{EnzymeTarget}) = true | ||||||
|
||||||
# TODO: encode debug build or not in the compiler job | ||||||
# https://github.com/JuliaGPU/CUDAnative.jl/issues/368 | ||||||
GPUCompiler.runtime_slug(job::CompilerJob{EnzymeTarget}) = "enzyme" | ||||||
|
||||||
# provide a specific interpreter to use. | ||||||
GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = | ||||||
|
@@ -8711,11 +8728,11 @@ end | |||||
# Map a batched annotation type to its unparameterized non-batched counterpart.
@inline function remove_innerty(::Type{<:BatchDuplicated})
    return Duplicated
end

@inline function remove_innerty(::Type{<:BatchDuplicatedNoNeed})
    return DuplicatedNoNeed
end
|
||||||
@generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false), ::Val{ShadowInit}=Val(false)) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World} | ||||||
@inline function innerthunk(World::UInt, FA::Type{<:Annotation}, A::Type{<:Annotation}, TT::Type{<:Tuple}, Mode::API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{N, Bool} where N, ReturnPrimal::Bool, ShadowInit::Bool, parent_job=nothing) | ||||||
mi = fspec(eltype(FA), TT, World) | ||||||
|
||||||
target = Compiler.EnzymeTarget() | ||||||
params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType) | ||||||
target = Compiler.EnzymeTarget(parent_job !== nothing ? parent_job.target : nothing) | ||||||
params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, parent_job !== nothing ? GPUCompiler.method_table(parent_job) : nothing) | ||||||
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) | ||||||
|
||||||
sig = Tuple{eltype(FA), map(eltype, TT.parameters)...} | ||||||
|
@@ -8750,25 +8767,31 @@ end | |||||
# This is counter-intuitive since we would expect the cache to be split | ||||||
# by the primal, but we want the generated code to be invalidated by | ||||||
# invalidations of the primal, which is managed by GPUCompiler. | ||||||
thunk = cached_compilation(job)::Thunk | ||||||
|
||||||
return thunk, rt | ||||||
end | ||||||
|
||||||
@generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, ::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false), ::Val{ShadowInit}=Val(false)) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World} | ||||||
parent_job=nothing | ||||||
thunk, rt = innerthunk(World, FA, A, TT, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit, parent_job) | ||||||
|
||||||
thunk = cached_compilation(job)::Thunk | ||||||
if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient | ||||||
TapeType = thunk.TapeType | ||||||
AugT = AugmentedForwardThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal), TapeType} | ||||||
AdjT = AdjointThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, TapeType} | ||||||
AugT = AugmentedForwardThunk{FA, rt, TT, Val{width}, Val(ReturnPrimal), TapeType} | ||||||
AdjT = AdjointThunk{FA, rt, TT, Val{width}, TapeType} | ||||||
return quote | ||||||
augmented = $AugT($(thunk.primal)) | ||||||
adjoint = $AdjT($(thunk.adjoint)) | ||||||
(augmented, adjoint) | ||||||
end | ||||||
elseif Mode == API.DEM_ReverseModeCombined | ||||||
CAdjT = CombinedAdjointThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)} | ||||||
CAdjT = CombinedAdjointThunk{FA, rt, TT, Val{width}, Val(ReturnPrimal)} | ||||||
return quote | ||||||
$CAdjT($(thunk.adjoint)) | ||||||
end | ||||||
elseif Mode == API.DEM_ForwardMode | ||||||
FMT = ForwardModeThunk{FA, rt, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)} | ||||||
FMT = ForwardModeThunk{FA, rt, TT, Val{width}, Val(ReturnPrimal)} | ||||||
return quote | ||||||
$FMT($(thunk.adjoint)) | ||||||
end | ||||||
|
@@ -8782,8 +8805,8 @@ import GPUCompiler: deferred_codegen_jobs | |||||
@generated function deferred_codegen(::Val{World}, ::Type{FA}, ::Val{tt}, ::Val{rt},::Val{Mode}, | ||||||
::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false),::Val{ShadowInit}=Val(false),::Type{ExpectedTapeType}=UnknownTapeType) where {World, FA<:Annotation,tt, rt, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType} | ||||||
mi = fspec(eltype(FA), tt, World) | ||||||
target = EnzymeTarget() | ||||||
params = EnzymeCompilerParams(Tuple{FA, tt.parameters...}, Mode, width, remove_innerty(rt), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType) | ||||||
target = EnzymeTarget(nothing) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Make this |
||||||
params = EnzymeCompilerParams(Tuple{FA, tt.parameters...}, Mode, width, remove_innerty(rt), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, nothing) | ||||||
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) | ||||||
|
||||||
adjoint_addr, primal_addr = get_trampoline(job) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.