Skip to content

Commit

Permalink
Utilities for splitting and unsplitting mode objects (#1979)
Browse files Browse the repository at this point in the history
* Utilities for splitting and unsplitting mode objects

* Remove Manifest

* Rename and add tests

* Add ABI tests

* Fix tests

* Add set_abi on mode type

* Rename to Split and Combined
  • Loading branch information
gdalle authored Nov 26, 2024
1 parent 5c373fb commit d096464
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 1 deletion.
99 changes: 98 additions & 1 deletion lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ Subtype of [`Mode`](@ref) for split reverse mode differentiation, to use in [`au
- [`set_abi`](@ref)
- [`ReverseSplitModified`](@ref), [`ReverseSplitWidth`](@ref)
"""
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI,Holomorphic,ErrIfFuncWritten,ShadowInit} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI,Holomorphic,ErrIfFuncWritten,ShadowInit} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end

"""
const ReverseSplitNoPrimal
Expand Down Expand Up @@ -432,6 +432,9 @@ Return a new instance of [`ReverseModeSplit`](@ref) mode where `Width` is set to
@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()
@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()

@inline set_abi(::Type{ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,OldABI,Holomorphic,ErrIfFuncWritten,ShadowInit}}, ::Type{NewABI}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,OldABI,Holomorphic,ErrIfFuncWritten,ShadowInit,NewABI<:ABI} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,NewABI,Holomorphic,ErrIfFuncWritten,ShadowInit}
@inline set_abi(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,OldABI,Holomorphic,ErrIfFuncWritten,ShadowInit}, ::Type{NewABI}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,OldABI,Holomorphic,ErrIfFuncWritten,ShadowInit,NewABI<:ABI} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,NewABI,Holomorphic,ErrIfFuncWritten,ShadowInit}()

@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()
@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()

Expand Down Expand Up @@ -594,6 +597,100 @@ Return a new mode with its [`ABI`](@ref) set to the chosen type.
"""
function set_abi end

"""
Split(
::ReverseMode, [::Val{ReturnShadow}, ::Val{Width}, ::Val{ModifiedBetween}, ::Val{ShadowInit}]
)
Turn a [`ReverseMode`](@ref) object into a [`ReverseModeSplit`](@ref) object while preserving as many of the settings as possible.
The rest of the settings can be configured with optional positional arguments of `Val` type.
This function acts as the identity on a [`ReverseModeSplit`](@ref).
See also [`Combined`](@ref).
"""
function Split(
::ReverseMode{
ReturnPrimal,
RuntimeActivity,
ABI,
Holomorphic,
ErrIfFuncWritten
},
::Val{ReturnShadow}=Val(true),
::Val{Width}=Val(0),
::Val{ModifiedBetween}=Val(true),
::Val{ShadowInit}=Val(false),
) where {
ReturnPrimal,
ReturnShadow,
RuntimeActivity,
Width,
ModifiedBetween,
ABI,
Holomorphic,
ErrIfFuncWritten,
ShadowInit
}
mode_split = ReverseModeSplit{
ReturnPrimal,
ReturnShadow,
RuntimeActivity,
Width,
ModifiedBetween,
ABI,
Holomorphic,
ErrIfFuncWritten,
ShadowInit
}()
return mode_split
end

Split(mode::ReverseModeSplit, args...) = mode

"""
Combined(::ReverseMode)
Turn a [`ReverseModeSplit`](@ref) object into a [`ReverseMode`](@ref) object while preserving as many of the settings as possible.
This function acts as the identity on a [`ReverseMode`](@ref).
See also [`Split`](@ref).
"""
function Combined(
::ReverseModeSplit{
ReturnPrimal,
ReturnShadow,
RuntimeActivity,
Width,
ModifiedBetween,
ABI,
Holomorphic,
ErrIfFuncWritten,
ShadowInit
}
) where {
ReturnPrimal,
ReturnShadow,
RuntimeActivity,
Width,
ModifiedBetween,
ABI,
Holomorphic,
ErrIfFuncWritten,
ShadowInit
}
mode_unsplit = ReverseMode{
ReturnPrimal,
RuntimeActivity,
ABI,
Holomorphic,
ErrIfFuncWritten
}()
return mode_unsplit
end

Combined(mode::ReverseMode) = mode

"""
Primitive Type usable within Reactant. See Reactant.jl for more information.
Expand Down
25 changes: 25 additions & 0 deletions lib/EnzymeCore/test/mode_modification.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using EnzymeCore
using EnzymeCore: InlineABI, ReverseModeSplit, Split, Combined, set_runtime_activity, set_err_if_func_written, set_abi
using Test

@testset "Split / unsplit mode" begin
@test Split(Reverse) == ReverseSplitNoPrimal
@test Split(ReverseWithPrimal) == ReverseSplitWithPrimal
@test Split(ReverseSplitNoPrimal) == ReverseSplitNoPrimal
@test Split(ReverseSplitWithPrimal) == ReverseSplitWithPrimal

@test Split(set_runtime_activity(Reverse)) == set_runtime_activity(ReverseSplitNoPrimal)
@test Split(set_err_if_func_written(Reverse)) == set_err_if_func_written(ReverseSplitNoPrimal)
@test Split(set_abi(Reverse, InlineABI)) == set_abi(ReverseSplitNoPrimal, InlineABI)

@test Split(Reverse, Val(:ReturnShadow), Val(:Width), Val(:ModifiedBetween), Val(:ShadowInit)) == ReverseModeSplit{false,:ReturnShadow,false,:Width,:ModifiedBetween,EnzymeCore.DefaultABI,false,false,:ShadowInit}()

@test Combined(Reverse) == Reverse
@test Combined(ReverseWithPrimal) == ReverseWithPrimal
@test Combined(ReverseSplitNoPrimal) == Reverse
@test Combined(ReverseSplitWithPrimal) == ReverseWithPrimal

@test Combined(set_runtime_activity(ReverseSplitNoPrimal)) == set_runtime_activity(Reverse)
@test Combined(set_err_if_func_written(ReverseSplitNoPrimal)) == set_err_if_func_written(Reverse)
@test Combined(set_abi(ReverseSplitNoPrimal, InlineABI)) == set_abi(Reverse, InlineABI)
end
3 changes: 3 additions & 0 deletions lib/EnzymeCore/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@ using EnzymeCore
@testset "Miscellaneous" begin
include("misc.jl")
end
@testset "Mode modification" begin
include("mode_modification.jl")
end
end

0 comments on commit d096464

Please sign in to comment.