diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index e5a8d7cf6..df1dd8601 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.26" +version = "0.6.27" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index d9ca885b5..578622e65 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -6,7 +6,11 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, + ::AutoReverseChainRules, + x, + ty::NTuple, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.NoPullbackPrep() end @@ -17,7 +21,7 @@ function DI.prepare_pullback_same_point( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) @@ -30,7 +34,7 @@ function DI.value_and_pullback( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) @@ -46,7 +50,7 @@ function DI.value_and_pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -61,7 +65,7 @@ function DI.pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; pb) = prep tx = map(ty) do dy diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 2ef5364ae..328bffaf3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -7,6 +7,7 @@ using EnzymeCore: Active, Annotation, BatchDuplicated, + BatchDuplicatedNoNeed, BatchMixedDuplicated, Combined, Const, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 433769eba..00b66808a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -18,12 +18,12 @@ function DI.value_and_pushforward( tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) + mode = forward_withprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) dx_sametype = convert(typeof(x), only(tx)) x_and_dx = Duplicated(x, dx_sametype) - dy, y = autodiff( - forward_withprimal(backend), f_and_df, x_and_dx, map(translate, contexts)... - ) + annotated_contexts = translate(backend, mode, Val(1), contexts...) + dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...) return y, (dy,) end @@ -35,12 +35,12 @@ function DI.value_and_pushforward( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f_and_df = get_f_and_df(f, backend, Val(B)) + mode = forward_withprimal(backend) + f_and_df = get_f_and_df(f, backend, mode, Val(B)) tx_sametype = map(Fix1(convert, typeof(x)), tx) x_and_tx = BatchDuplicated(x, tx_sametype) - ty, y = autodiff( - forward_withprimal(backend), f_and_df, x_and_tx, map(translate, contexts)... - ) + annotated_contexts = translate(backend, mode, Val(B), contexts...) + ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...) return y, values(ty) end @@ -52,12 +52,12 @@ function DI.pushforward( tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) + mode = forward_noprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) dx_sametype = convert(typeof(x), only(tx)) x_and_dx = Duplicated(x, dx_sametype) - dy = only( - autodiff(forward_noprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...) - ) + annotated_contexts = translate(backend, mode, Val(1), contexts...) + dy = only(autodiff(mode, f_and_df, x_and_dx, annotated_contexts...)) return (dy,) end @@ -69,12 +69,12 @@ function DI.pushforward( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f_and_df = get_f_and_df(f, backend, Val(B)) + mode = forward_noprimal(backend) + f_and_df = get_f_and_df(f, backend, mode, Val(B)) tx_sametype = map(Fix1(convert, typeof(x)), tx) x_and_tx = BatchDuplicated(x, tx_sametype) - ty = only( - autodiff(forward_noprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...) - ) + annotated_contexts = translate(backend, mode, Val(B), contexts...) + ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)) return values(ty) end @@ -132,10 +132,9 @@ function DI.gradient( backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, ) where {F,B} - f_and_df = get_f_and_df(f, backend) - derivs = gradient( - forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows - ) + mode = forward_noprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) + derivs = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) return only(derivs) end @@ -145,10 +144,9 @@ function DI.value_and_gradient( backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, ) where {F,B} - f_and_df = get_f_and_df(f, backend) - (; derivs, val) = gradient( - forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows - ) + mode = forward_withprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) + (; derivs, val) = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) return val, only(derivs) end @@ -201,10 +199,9 @@ function DI.jacobian( backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F,B} - f_and_df = get_f_and_df(f, backend) - derivs = jacobian( - forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows - ) + mode = forward_noprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) + derivs = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) jac_tensor = only(derivs) return maybe_reshape(jac_tensor, prep.output_length, length(x)) end @@ -215,10 +212,9 @@ function DI.value_and_jacobian( backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F,B} - f_and_df = get_f_and_df(f, backend) - (; derivs, val) = jacobian( - forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows - ) + mode = forward_withprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) + (; derivs, val) = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) jac_tensor = only(derivs) return val, maybe_reshape(jac_tensor, prep.output_length, length(x)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 58e77d25f..b185b4a16 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -20,19 +20,14 @@ function DI.value_and_pushforward( tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f!_and_df! = get_f_and_df(f!, backend) + mode = forward_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode) dx_sametype = convert(typeof(x), only(tx)) dy_sametype = make_zero(y) x_and_dx = Duplicated(x, dx_sametype) y_and_dy = Duplicated(y, dy_sametype) - autodiff( - forward_noprimal(backend), - f!_and_df!, - Const, - y_and_dy, - x_and_dx, - map(translate, contexts)..., - ) + annotated_contexts = translate(backend, mode, Val(1), contexts...) + autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...) return y, (dy_sametype,) end @@ -45,19 +40,14 @@ function DI.value_and_pushforward( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f!_and_df! = get_f_and_df(f!, backend, Val(B)) + mode = forward_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) tx_sametype = map(Fix1(convert, typeof(x)), tx) ty_sametype = ntuple(_ -> make_zero(y), Val(B)) x_and_tx = BatchDuplicated(x, tx_sametype) y_and_ty = BatchDuplicated(y, ty_sametype) - autodiff( - forward_noprimal(backend), - f!_and_df!, - Const, - y_and_ty, - x_and_tx, - map(translate, contexts)..., - ) + annotated_contexts = translate(backend, mode, Val(B), contexts...) + autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...) return y, ty_sametype end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 660d40520..5585f717e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -69,13 +69,14 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = force_annotation(get_f_and_df(f, backend)) mode = reverse_split_withprimal(backend) + f_and_df = force_annotation(get_f_and_df(f, backend, mode)) IA = guess_activity(typeof(x), mode) RA = guess_activity(eltype(ty), mode) dx = make_zero(x) + annotated_contexts = translate(backend, mode, Val(1), contexts...) dinputs, result = seeded_autodiff_thunk( - mode, only(ty), f_and_df, RA, annotate(IA, x, dx), map(translate, contexts)... + mode, only(ty), f_and_df, RA, annotate(IA, x, dx), annotated_contexts... ) new_dx = first(dinputs) if isnothing(new_dx) @@ -93,13 +94,14 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f_and_df = force_annotation(get_f_and_df(f, backend, Val(B))) mode = reverse_split_withprimal(backend) + f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B))) IA = batchify_activity(guess_activity(typeof(x), mode), Val(B)) RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B)) tx = ntuple(_ -> make_zero(x), Val(B)) + annotated_contexts = translate(backend, mode, Val(B), contexts...) dinputs, result = batch_seeded_autodiff_thunk( - mode, ty, f_and_df, RA, annotate(IA, x, tx), map(translate, contexts)... + mode, ty, f_and_df, RA, annotate(IA, x, tx), annotated_contexts... ) new_tx = values(first(dinputs)) if isnothing(new_tx) @@ -131,18 +133,14 @@ function DI.value_and_pullback!( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = force_annotation(get_f_and_df(f, backend)) mode = reverse_split_withprimal(backend) + f_and_df = force_annotation(get_f_and_df(f, backend, mode)) RA = guess_activity(eltype(ty), mode) dx_righttype = convert(typeof(x), only(tx)) make_zero!(dx_righttype) + annotated_contexts = translate(backend, mode, Val(1), contexts...) _, result = seeded_autodiff_thunk( - mode, - only(ty), - f_and_df, - RA, - Duplicated(x, dx_righttype), - map(translate, contexts)..., + mode, only(ty), f_and_df, RA, Duplicated(x, dx_righttype), annotated_contexts... ) only(tx) === dx_righttype || copyto!(only(tx), dx_righttype) return result, tx @@ -157,18 +155,14 @@ function DI.value_and_pullback!( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f_and_df = force_annotation(get_f_and_df(f, backend, Val(B))) mode = reverse_split_withprimal(backend) + f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B))) RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B)) tx_righttype = map(Fix1(convert, typeof(x)), tx) make_zero!(tx_righttype) + annotated_contexts = translate(backend, mode, Val(B), contexts...) _, result = batch_seeded_autodiff_thunk( - mode, - ty, - f_and_df, - RA, - BatchDuplicated(x, tx_righttype), - map(translate, contexts)..., + mode, ty, f_and_df, RA, BatchDuplicated(x, tx_righttype), annotated_contexts... ) foreach(copyto!, tx, tx_righttype) return result, tx @@ -196,12 +190,13 @@ function DI.gradient( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) mode = reverse_noprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) IA = guess_activity(typeof(x), mode) grad = make_zero(x) + annotated_contexts = translate(backend, mode, Val(1), contexts...) dinputs = only( - autodiff(mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...) + autodiff(mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts...) ) new_grad = first(dinputs) if isnothing(new_grad) @@ -217,12 +212,13 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) mode = reverse_withprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) IA = guess_activity(typeof(x), mode) grad = make_zero(x) + annotated_contexts = translate(backend, mode, Val(1), contexts...) dinputs, result = autodiff( - mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)... + mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts... ) new_grad = first(dinputs) if isnothing(new_grad) @@ -263,16 +259,12 @@ function DI.gradient!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) + mode = reverse_noprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype make_zero!(grad_righttype) - autodiff( - reverse_noprimal(backend), - f_and_df, - Active, - Duplicated(x, grad_righttype), - map(translate, contexts)..., - ) + annotated_contexts = translate(backend, mode, Val(1), contexts...) + autodiff(mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...) grad === grad_righttype || copyto!(grad, grad_righttype) return grad end @@ -295,15 +287,13 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) + mode = reverse_withprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype make_zero!(grad_righttype) + annotated_contexts = translate(backend, mode, Val(1), contexts...) _, y = autodiff( - reverse_withprimal(backend), - f_and_df, - Active, - Duplicated(x, grad_righttype), - map(translate, contexts)..., + mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts... ) grad === grad_righttype || copyto!(grad, grad_righttype) return y, grad diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 4da3ece8c..c93d04dab 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -20,18 +20,13 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f!_and_df! = get_f_and_df(f!, backend) + mode = reverse_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode) dy_sametype = convert(typeof(y), copy(only(ty))) y_and_dy = Duplicated(y, dy_sametype) + annotated_contexts = translate(backend, mode, Val(1), contexts...) dinputs = only( - autodiff( - reverse_noprimal(backend), - f!_and_df!, - Const, - y_and_dy, - Active(x), - map(translate, contexts)..., - ), + autodiff(mode, f!_and_df!, Const, y_and_dy, Active(x), annotated_contexts...) ) dx = dinputs[2] return y, (dx,) @@ -46,18 +41,13 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f!_and_df! = get_f_and_df(f!, backend, Val(B)) + mode = reverse_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) ty_sametype = map(Fix1(convert, typeof(y)), copy.(ty)) y_and_ty = BatchDuplicated(y, ty_sametype) + annotated_contexts = translate(backend, mode, Val(B), contexts...) dinputs = only( - autodiff( - reverse_noprimal(backend), - f!_and_df!, - Const, - y_and_ty, - Active(x), - map(translate, contexts)..., - ), + autodiff(mode, f!_and_df!, Const, y_and_ty, Active(x), annotated_contexts...) ) tx = values(dinputs[2]) return y, tx @@ -72,19 +62,14 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f!_and_df! = get_f_and_df(f!, backend) + mode = reverse_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode) dx_sametype = make_zero(x) dy_sametype = convert(typeof(y), copy(only(ty))) x_and_dx = Duplicated(x, dx_sametype) y_and_dy = Duplicated(y, dy_sametype) - autodiff( - reverse_noprimal(backend), - f!_and_df!, - Const, - y_and_dy, - x_and_dx, - map(translate, contexts)..., - ) + annotated_contexts = translate(backend, mode, Val(1), contexts...) + autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...) return y, (dx_sametype,) end @@ -97,18 +82,13 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f!_and_df! = get_f_and_df(f!, backend, Val(B)) + mode = reverse_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) tx_sametype = ntuple(_ -> make_zero(x), Val(B)) ty_sametype = map(Fix1(convert, typeof(y)), copy.(ty)) x_and_tx = BatchDuplicated(x, tx_sametype) y_and_ty = BatchDuplicated(y, ty_sametype) - autodiff( - reverse_noprimal(backend), - f!_and_df!, - Const, - y_and_ty, - x_and_tx, - map(translate, contexts)..., - ) + annotated_contexts = translate(backend, mode, Val(B), contexts...) + autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...) return y, tx_sametype end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 6c847faf9..dd8221768 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -8,15 +8,19 @@ to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B) ## Annotations -function get_f_and_df(f::F, ::AutoEnzyme{M,Nothing}, ::Val{B}=Val(1)) where {F,M,B} +@inline function get_f_and_df( + f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1) +) where {F,M,B} return f end -function get_f_and_df(f::F, ::AutoEnzyme{M,<:Const}, ::Val{B}=Val(1)) where {F,M,B} +@inline function get_f_and_df( + f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1) +) where {F,M,B} return Const(f) end -function get_f_and_df( +@inline function get_f_and_df( f::F, ::AutoEnzyme{ M, @@ -25,10 +29,11 @@ function get_f_and_df( MixedDuplicated, BatchDuplicated, BatchMixedDuplicated, - EnzymeCore.DuplicatedNoNeed, - EnzymeCore.BatchDuplicatedNoNeed, + DuplicatedNoNeed, + BatchDuplicatedNoNeed, }, }, + mode::Mode, ::Val{B}=Val(1), ) where {F,M,B} # TODO: needs more sophistication for mixed activities @@ -42,7 +47,26 @@ end force_annotation(f::F) where {F<:Annotation} = f force_annotation(f::F) where {F} = Const(f) -translate(c::DI.Constant) = Const(DI.unwrap(c)) +@inline function _translate( + ::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext} +) where {B} + return Const(DI.unwrap(c)) +end + +@inline function _translate( + backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext +) where {B} + return force_annotation(get_f_and_df(DI.unwrap(c), backend, mode, Val(B))) +end + +@inline function translate( + backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context,C} +) where {B,C} + new_contexts = map(contexts) do c + _translate(backend, mode, Val(B), c) + end + return new_contexts +end ## Modes diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 90372a9e2..f50f030a7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -170,7 +170,7 @@ struct ForwardDiffOneArgDerivativePrep{E} <: DI.DerivativePrep end function DI.prepare_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} pushforward_prep = DI.prepare_pushforward(f, backend, x, (one(x),), contexts...) return ForwardDiffOneArgDerivativePrep(pushforward_prep) @@ -181,7 +181,7 @@ function DI.value_and_derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} y, ty = DI.value_and_pushforward( f, prep.pushforward_prep, backend, x, (one(x),), contexts... @@ -195,7 +195,7 @@ function DI.value_and_derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} y, _ = DI.value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... @@ -208,7 +208,7 @@ function DI.derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return only( DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) @@ -221,7 +221,7 @@ function DI.derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) return der @@ -232,7 +232,11 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_gradient!( - f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + grad, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -248,7 +252,10 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -262,7 +269,11 @@ function DI.value_and_gradient( end function DI.gradient!( - f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + grad, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -274,7 +285,10 @@ function DI.gradient!( end function DI.gradient( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -292,7 +306,10 @@ struct ForwardDiffGradientPrep{C} <: DI.GradientPrep end function DI.prepare_gradient( - f::F, backend::AutoForwardDiff, x::AbstractArray, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff, + x::AbstractArray, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -307,7 +324,7 @@ function DI.value_and_gradient!( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = DiffResult(zero(eltype(x)), (grad,)) @@ -323,7 +340,7 @@ function DI.value_and_gradient( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = GradientResult(x) @@ -338,7 +355,7 @@ function DI.gradient!( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -350,7 +367,7 @@ function DI.gradient( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -362,7 +379,11 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_jacobian!( - f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + jac, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -379,7 +400,10 @@ function DI.value_and_jacobian!( end function DI.value_and_jacobian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -391,7 +415,11 @@ function DI.value_and_jacobian( end function DI.jacobian!( - f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + jac, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -403,7 +431,10 @@ function DI.jacobian!( end function DI.jacobian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -421,7 +452,7 @@ struct ForwardDiffOneArgJacobianPrep{C} <: DI.JacobianPrep end function DI.prepare_jacobian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -436,7 +467,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) y = fc(x) @@ -453,7 +484,7 @@ function DI.value_and_jacobian( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -466,7 +497,7 @@ function DI.jacobian!( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -478,7 +509,7 @@ function DI.jacobian( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -488,7 +519,7 @@ end ## Second derivative function DI.prepare_second_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} return DI.NoSecondDerivativePrep() end @@ -498,7 +529,7 @@ function DI.second_derivative( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -513,7 +544,7 @@ function DI.second_derivative!( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -527,7 +558,7 @@ function DI.value_derivative_and_second_derivative( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -546,7 +577,7 @@ function DI.value_derivative_and_second_derivative!( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -561,7 +592,11 @@ end ## HVP function DI.prepare_hvp( - f::F, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.prepare_hvp(f, DI.SecondOrder(backend, backend), x, tx, contexts...) end @@ -572,7 +607,7 @@ function DI.hvp( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.hvp(f, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) end @@ -584,7 +619,7 @@ function DI.hvp!( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.hvp!(f, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) end @@ -595,7 +630,7 @@ function DI.gradient_and_hvp( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.gradient_and_hvp( f, prep, DI.SecondOrder(backend, backend), x, tx, contexts... @@ -610,7 +645,7 @@ function DI.gradient_and_hvp!( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.gradient_and_hvp!( f, grad, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts... @@ -622,7 +657,11 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.hessian!( - f::F, hess, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + hess, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -634,7 +673,10 @@ function DI.hessian!( end function DI.hessian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -651,7 +693,7 @@ function DI.value_gradient_and_hessian!( hess, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -668,7 +710,10 @@ function DI.value_gradient_and_hessian!( end function DI.value_gradient_and_hessian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -689,7 +734,7 @@ struct ForwardDiffHessianPrep{C1,C2} <: DI.HessianPrep end function DI.prepare_hessian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -706,7 +751,7 @@ function DI.hessian!( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -718,7 +763,7 @@ function DI.hessian( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -732,7 +777,7 @@ function DI.value_gradient_and_hessian!( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = DiffResult(one(eltype(x)), (grad, hess)) @@ -749,7 +794,7 @@ function DI.value_gradient_and_hessian( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = HessianResult(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl index 35e743d89..d01efd074 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl @@ -16,9 +16,9 @@ function DI.prepare_hvp( inner_gradient_prep = DI.prepare_gradient(f, DI.inner(backend), xdual, contexts...) rewrap = DI.Rewrap(contexts...) new_contexts = ( - DI.Constant(f), - DI.PrepContext(inner_gradient_prep), - DI.Constant(DI.inner(backend)), + DI.FunctionContext(f), + PrepContext(inner_gradient_prep), + DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., ) @@ -39,9 +39,9 @@ function DI.hvp( (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = DI.Rewrap(contexts...) new_contexts = ( - DI.Constant(f), - DI.PrepContext(inner_gradient_prep), - DI.Constant(DI.inner(backend)), + DI.FunctionContext(f), + PrepContext(inner_gradient_prep), + DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., ) @@ -67,9 +67,9 @@ function DI.hvp!( (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = DI.Rewrap(contexts...) new_contexts = ( - DI.Constant(f), - DI.PrepContext(inner_gradient_prep), - DI.Constant(DI.inner(backend)), + DI.FunctionContext(f), + PrepContext(inner_gradient_prep), + DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., ) @@ -96,9 +96,9 @@ function DI.gradient_and_hvp( (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = DI.Rewrap(contexts...) new_contexts = ( - DI.Constant(f), - DI.PrepContext(inner_gradient_prep), - DI.Constant(DI.inner(backend)), + DI.FunctionContext(f), + PrepContext(inner_gradient_prep), + DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., ) @@ -125,9 +125,9 @@ function DI.gradient_and_hvp!( (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = DI.Rewrap(contexts...) new_contexts = ( - DI.Constant(f), - DI.PrepContext(inner_gradient_prep), - DI.Constant(DI.inner(backend)), + DI.FunctionContext(f), + PrepContext(inner_gradient_prep), + DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 2fdbc8839..5ffbf8412 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -111,7 +111,11 @@ end ### Unprepared, only when tag is not specified function DI.value_and_derivative( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -125,7 +129,12 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + der, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -139,7 +148,11 @@ function DI.value_and_derivative!( end function DI.derivative( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -151,7 +164,12 @@ function DI.derivative( end function DI.derivative!( - f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + der, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -169,7 +187,11 @@ struct ForwardDiffTwoArgDerivativePrep{C} <: DI.DerivativePrep end function DI.prepare_derivative( - f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) tag = get_tag(fc!, backend, x) @@ -183,7 +205,7 @@ function DI.value_and_derivative( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (similar(y),)) @@ -199,7 +221,7 @@ function DI.value_and_derivative!( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (der,)) @@ -214,7 +236,7 @@ function DI.derivative( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -228,7 +250,7 @@ function DI.derivative!( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -240,7 +262,11 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_jacobian( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -255,7 +281,12 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!::F, y, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + jac, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -269,7 +300,11 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -281,7 +316,12 @@ function DI.jacobian( end function DI.jacobian!( - f!::F, y, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + jac, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -299,7 +339,11 @@ struct ForwardDiffTwoArgJacobianPrep{C} <: DI.JacobianPrep end function DI.prepare_jacobian( - f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) chunk = choose_chunk(backend, x) @@ -314,7 +358,7 @@ function DI.value_and_jacobian( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -331,7 +375,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (jac,)) @@ -346,7 +390,7 @@ function DI.jacobian( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -360,7 +404,7 @@ function DI.jacobian!( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index aa2546812..b4893ef36 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -77,8 +77,15 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} return ty end -_translate(::Type{T}, ::Val{B}, c::DI.Constant) where {T,B} = DI.unwrap(c) -_translate(::Type{T}, ::Val{B}, c::DI.PrepContext) where {T,B} = DI.unwrap(c) +# store preparation result with the right input eltype +struct PrepContext{T<:DI.Prep} <: DI.Context + data::T +end + +function _translate(::Type{T}, ::Val{B}, c::DI.ConstantOrFunctionOrBackend) where {T,B} + return DI.unwrap(c) +end +_translate(::Type{T}, ::Val{B}, c::PrepContext) where {T,B} = DI.unwrap(c) function _translate(::Type{T}, ::Val{B}, c::DI.Cache) where {T,B} c0 = DI.unwrap(c) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index 0f84df11f..fb5da6b76 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -15,20 +15,30 @@ struct TrackerPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {C} return DI.NoPullbackPrep() end function DI.prepare_pullback_same_point( - f, ::DI.NoPullbackPrep, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoPullbackPrep, + ::AutoTracker, + x, + ty::NTuple, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) return TrackerPullbackPrepSamePoint(y, pb) end function DI.value_and_pullback( - f, ::DI.NoPullbackPrep, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoPullbackPrep, + ::AutoTracker, + x, + ty::NTuple, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) tx = map(ty) do dy @@ -43,7 +53,7 @@ function DI.value_and_pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -58,7 +68,7 @@ function DI.pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; pb) = prep tx = map(ty) do dy @@ -69,19 +79,29 @@ end ## Gradient -function DI.prepare_gradient(f, ::AutoTracker, x, contexts::Vararg{DI.Constant,C}) where {C} +function DI.prepare_gradient( + f, ::AutoTracker, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} +) where {C} return DI.NoGradientPrep() end function DI.value_and_gradient( - f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoGradientPrep, + ::AutoTracker, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, data(first(grad)) end function DI.gradient( - f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoGradientPrep, + ::AutoTracker, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return data(first(grad)) @@ -93,7 +113,7 @@ function DI.value_and_gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) @@ -105,7 +125,7 @@ function DI.gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index b2f23cd65..0f681c9f9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -17,20 +17,30 @@ struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {C} return DI.NoPullbackPrep() end function DI.prepare_pullback_same_point( - f, ::DI.NoPullbackPrep, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoPullbackPrep, + ::AutoZygote, + x, + ty::NTuple, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = pullback(f, x, map(DI.unwrap, contexts)...) return ZygotePullbackPrepSamePoint(y, pb) end function DI.value_and_pullback( - f, ::DI.NoPullbackPrep, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoPullbackPrep, + ::AutoZygote, + x, + ty::NTuple, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = pullback(f, x, map(DI.unwrap, contexts)...) tx = map(ty) do dy @@ -45,7 +55,7 @@ function DI.value_and_pullback( ::AutoZygote, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -60,7 +70,7 @@ function DI.pullback( ::AutoZygote, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; pb) = prep tx = map(ty) do dy @@ -71,19 +81,29 @@ end ## Gradient -function DI.prepare_gradient(f, ::AutoZygote, x, contexts::Vararg{DI.Constant,C}) where {C} +function DI.prepare_gradient( + f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} +) where {C} return DI.NoGradientPrep() end function DI.value_and_gradient( - f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoGradientPrep, + ::AutoZygote, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, first(grad) end function DI.gradient( - f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoGradientPrep, + ::AutoZygote, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return first(gradient(f, x, map(DI.unwrap, contexts)...)) end @@ -94,7 +114,7 @@ function DI.value_and_gradient!( prep::DI.NoGradientPrep, backend::AutoZygote, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) @@ -106,39 +126,59 @@ function DI.gradient!( prep::DI.NoGradientPrep, backend::AutoZygote, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end ## Jacobian -function DI.prepare_jacobian(f, ::AutoZygote, x, contexts::Vararg{DI.Constant,C}) where {C} +function DI.prepare_jacobian( + f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} +) where {C} return DI.NoJacobianPrep() end function DI.value_and_jacobian( - f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoJacobianPrep, + ::AutoZygote, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return f(x, map(DI.unwrap, contexts)...), first(jacobian(f, x, map(DI.unwrap, contexts)...)) # https://github.com/FluxML/Zygote.jl/issues/1506 end function DI.jacobian( - f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoJacobianPrep, + ::AutoZygote, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return first(jacobian(f, x, map(DI.unwrap, contexts)...)) end function DI.value_and_jacobian!( - f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + jac, + prep::DI.NoJacobianPrep, + backend::AutoZygote, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) return y, copyto!(jac, new_jac) end function DI.jacobian!( - f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + jac, + prep::DI.NoJacobianPrep, + backend::AutoZygote, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -148,13 +188,22 @@ end # Beware, this uses ForwardDiff for the inner differentiation function DI.prepare_hvp( - f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Constant,C} + f, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.prepare_hvp(f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) end function DI.hvp( - f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Constant,C} + f, + prep::DI.HVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) end @@ -166,7 +215,7 @@ function DI.hvp!( backend::AutoZygote, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.hvp!( f, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -174,7 +223,12 @@ function DI.hvp!( end function DI.gradient_and_hvp( - f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Constant,C} + f, + prep::DI.HVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.gradient_and_hvp( f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -189,7 +243,7 @@ function DI.gradient_and_hvp!( backend::AutoZygote, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.gradient_and_hvp!( f, grad, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -198,25 +252,40 @@ end ## Hessian -function DI.prepare_hessian(f, ::AutoZygote, x, contexts::Vararg{DI.Constant,C}) where {C} +function DI.prepare_hessian( + f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} +) where {C} return DI.NoHessianPrep() end function DI.hessian( - f, ::DI.NoHessianPrep, ::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoHessianPrep, + ::AutoZygote, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} fc = DI.with_contexts(f, contexts...) return hessian(fc, x) end function DI.hessian!( - f, hess, prep::DI.NoHessianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + hess, + prep::DI.NoHessianPrep, + backend::AutoZygote, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(hess, DI.hessian(f, prep, backend, x, contexts...)) end function DI.value_gradient_and_hessian( - f, prep::DI.NoHessianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + prep::DI.NoHessianPrep, + backend::AutoZygote, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, grad = DI.value_and_gradient(f, DI.NoGradientPrep(), backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) @@ -230,7 +299,7 @@ function DI.value_gradient_and_hessian!( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, _ = DI.value_and_gradient!(f, grad, DI.NoGradientPrep(), backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 4f53d81d5..cea59e7ef 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -90,7 +90,9 @@ function _prepare_hvp_aux( contexts::Vararg{Context,C}, ) where {F,C} rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) outer_pushforward_prep = prepare_pushforward( shuffled_gradient, outer(backend), x, tx, new_contexts... ) @@ -107,7 +109,9 @@ function hvp( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) @@ -124,7 +128,9 @@ function hvp!( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pushforward!( shuffled_gradient, tg, @@ -146,7 +152,9 @@ function gradient_and_hvp( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return value_and_pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) @@ -164,7 +172,9 @@ function gradient_and_hvp!( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) new_grad, _ = value_and_pushforward!( shuffled_gradient, tg, @@ -193,7 +203,9 @@ function _prepare_hvp_aux( contexts::Vararg{Context,C}, ) where {F,C} rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) outer_pushforward_prep = prepare_pushforward( shuffled_gradient, outer(backend), x, tx, new_contexts... ) @@ -210,7 +222,9 @@ function hvp( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) @@ -227,7 +241,9 @@ function hvp!( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pushforward!( shuffled_gradient, tg, @@ -249,7 +265,9 @@ function gradient_and_hvp( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return value_and_pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) @@ -267,7 +285,9 @@ function gradient_and_hvp!( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) new_grad, _ = value_and_pushforward!( shuffled_gradient, tg, @@ -298,8 +318,8 @@ function _prepare_hvp_aux( ) where {F,C} rewrap = Rewrap(contexts...) new_contexts = ( - Constant(f), - Constant(inner(backend)), + FunctionContext(f), + BackendContext(inner(backend)), Constant(first(tx)), Constant(rewrap), contexts..., @@ -327,8 +347,8 @@ function hvp( outer_gradient_prep, outer(backend), x, - Constant(f), - Constant(inner(backend)), + FunctionContext(f), + BackendContext(inner(backend)), Constant(dx), Constant(rewrap), contexts..., @@ -355,8 +375,8 @@ function hvp!( outer_gradient_prep, outer(backend), x, - Constant(f), - Constant(inner(backend)), + FunctionContext(f), + BackendContext(inner(backend)), Constant(tx[b]), Constant(rewrap), contexts..., @@ -409,7 +429,9 @@ function _prepare_hvp_aux( contexts::Vararg{Context,C}, ) where {F,C} rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) outer_pullback_prep = prepare_pullback( shuffled_gradient, outer(backend), x, tx, new_contexts... ) @@ -426,7 +448,9 @@ function hvp( ) where {F,C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -443,7 +467,9 @@ function hvp!( ) where {F,C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -459,7 +485,9 @@ function gradient_and_hvp( ) where {F,C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return value_and_pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -477,7 +505,9 @@ function gradient_and_hvp!( ) where {F,C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) new_grad, _ = value_and_pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 02c0628af..c39c65cf4 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -56,7 +56,9 @@ function prepare_second_derivative( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) outer_derivative_prep = prepare_derivative( shuffled_derivative, outer(backend), x, new_contexts... ) @@ -74,7 +76,9 @@ function second_derivative( ) where {F,C} (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return derivative( shuffled_derivative, outer_derivative_prep, outer(backend), x, new_contexts... ) @@ -89,7 +93,9 @@ function value_derivative_and_second_derivative( ) where {F,C} (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) y = f(x, map(unwrap, contexts)...) der, der2 = value_and_derivative( shuffled_derivative, outer_derivative_prep, outer(backend), x, new_contexts... @@ -107,7 +113,9 @@ function second_derivative!( ) where {F,C} (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return derivative!( shuffled_derivative, der2, outer_derivative_prep, outer(backend), x, new_contexts... ) @@ -124,7 +132,9 @@ function value_derivative_and_second_derivative!( ) where {F,C} (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) y = f(x, map(unwrap, contexts)...) new_der, _ = value_and_derivative!( shuffled_derivative, der2, outer_derivative_prep, outer(backend), x, new_contexts... diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 310017490..1abdbdd2e 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -19,6 +19,11 @@ Abstract supertype for additional context arguments, which can be passed to diff """ abstract type Context end +unwrap(c::Context) = c.data +Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2) + +## Public contexts + """ Constant @@ -53,9 +58,6 @@ end constant_maker(c) = Constant(c) maker(::Constant) = constant_maker -unwrap(c::Constant) = c.data - -Base.:(==)(c1::Constant, c2::Constant) = c1.data == c2.data """ Cache @@ -70,15 +72,20 @@ end cache_maker(c) = Cache(c) maker(::Cache) = cache_maker -unwrap(c::Cache) = c.data -Base.:(==)(c1::Cache, c2::Cache) = c1.data == c2.data +## Internal contexts for passing stuff around -struct PrepContext{T<:Prep} <: Context +struct FunctionContext{T} <: Context data::T end -unwrap(c::PrepContext) = c.data +struct BackendContext{T} <: Context + data::T +end + +const ConstantOrFunctionOrBackend = Union{Constant,FunctionContext,BackendContext} + +## Context manipulation struct Rewrap{C,T} context_makers::T