Skip to content

Commit

Permalink
Split out extras types (#422)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Aug 27, 2024
1 parent 1441260 commit 6c829fc
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 74 deletions.
1 change: 1 addition & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ abstract type Extras end

include("second_order/second_order.jl")

include("utils/extras.jl")
include("utils/traits.jl")
include("utils/basis.jl")
include("utils/batch.jl")
Expand Down
9 changes: 0 additions & 9 deletions DifferentiationInterface/src/first_order/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,6 @@ function derivative! end

## Preparation

"""
DerivativeExtras
Abstract type for additional information needed by [`derivative`](@ref) and its variants.
"""
abstract type DerivativeExtras <: Extras end

struct NoDerivativeExtras <: DerivativeExtras end

struct PushforwardDerivativeExtras{E<:PushforwardExtras} <: DerivativeExtras
pushforward_extras::E
end
Expand Down
9 changes: 0 additions & 9 deletions DifferentiationInterface/src/first_order/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@ function gradient! end

## Preparation

"""
GradientExtras
Abstract type for additional information needed by [`gradient`](@ref) and its variants.
"""
abstract type GradientExtras <: Extras end

struct NoGradientExtras <: GradientExtras end

struct PullbackGradientExtras{E<:PullbackExtras} <: GradientExtras
pullback_extras::E
end
Expand Down
9 changes: 0 additions & 9 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,6 @@ function jacobian! end

## Preparation

"""
JacobianExtras
Abstract type for additional information needed by [`jacobian`](@ref) and its variants.
"""
abstract type JacobianExtras <: Extras end

struct NoJacobianExtras <: JacobianExtras end

struct PushforwardJacobianExtras{B,D,R,E<:PushforwardExtras} <: JacobianExtras
batched_seeds::Vector{Batch{B,D}}
batched_results::Vector{Batch{B,R}}
Expand Down
11 changes: 1 addition & 10 deletions DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,7 @@ function pullback! end

### Extras types

"""
PullbackExtras
Abstract type for additional information needed by [`pullback`](@ref) and its variants.
"""
abstract type PullbackExtras <: Extras end

struct NoPullbackExtras <: PullbackExtras end

struct PushforwardPullbackExtras{E} <: PullbackExtras
struct PushforwardPullbackExtras{E<:PushforwardExtras} <: PullbackExtras
pushforward_extras::E
end

Expand Down
11 changes: 1 addition & 10 deletions DifferentiationInterface/src/first_order/pushforward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,7 @@ function pushforward! end

### Extras types

"""
PushforwardExtras
Abstract type for additional information needed by [`pushforward`](@ref) and its variants.
"""
abstract type PushforwardExtras <: Extras end

struct NoPushforwardExtras <: PushforwardExtras end

struct PullbackPushforwardExtras{E} <: PushforwardExtras
struct PullbackPushforwardExtras{E<:PullbackExtras} <: PushforwardExtras
pullback_extras::E
end

Expand Down
9 changes: 0 additions & 9 deletions DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@ function value_gradient_and_hessian! end

## Preparation

"""
HessianExtras
Abstract type for additional information needed by [`hessian`](@ref) and its variants.
"""
abstract type HessianExtras <: Extras end

struct NoHessianExtras <: HessianExtras end

struct HVPGradientHessianExtras{B,D,R,E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras
batched_seeds::Vector{Batch{B,D}}
batched_results::Vector{Batch{B,R}}
Expand Down
9 changes: 0 additions & 9 deletions DifferentiationInterface/src/second_order/hvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,6 @@ function hvp! end

### Extras types

"""
HVPExtras
Abstract type for additional information needed by [`hvp`](@ref) and its variants.
"""
abstract type HVPExtras <: Extras end

struct NoHVPExtras <: HVPExtras end

struct ForwardOverForwardHVPExtras{G<:Gradient,E<:PushforwardExtras} <: HVPExtras
inner_gradient::G
outer_pushforward_extras::E
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@ function value_derivative_and_second_derivative! end

## Preparation

"""
SecondDerivativeExtras
Abstract type for additional information needed by [`second_derivative`](@ref) and its variants.
"""
abstract type SecondDerivativeExtras <: Extras end

struct NoSecondDerivativeExtras <: SecondDerivativeExtras end

struct InnerDerivative{F,B}
f::F
backend::B
Expand Down
65 changes: 65 additions & 0 deletions DifferentiationInterface/src/utils/extras.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
abstract type Extras end

"""
PushforwardExtras
Abstract type for additional information needed by [`pushforward`](@ref) and its variants.
"""
abstract type PushforwardExtras <: Extras end
struct NoPushforwardExtras <: PushforwardExtras end

"""
PullbackExtras
Abstract type for additional information needed by [`pullback`](@ref) and its variants.
"""
abstract type PullbackExtras <: Extras end
struct NoPullbackExtras <: PullbackExtras end

"""
DerivativeExtras
Abstract type for additional information needed by [`derivative`](@ref) and its variants.
"""
abstract type DerivativeExtras <: Extras end
struct NoDerivativeExtras <: DerivativeExtras end

"""
GradientExtras
Abstract type for additional information needed by [`gradient`](@ref) and its variants.
"""
abstract type GradientExtras <: Extras end
struct NoGradientExtras <: GradientExtras end

"""
JacobianExtras
Abstract type for additional information needed by [`jacobian`](@ref) and its variants.
"""
abstract type JacobianExtras <: Extras end
struct NoJacobianExtras <: JacobianExtras end

"""
HVPExtras
Abstract type for additional information needed by [`hvp`](@ref) and its variants.
"""
abstract type HVPExtras <: Extras end
struct NoHVPExtras <: HVPExtras end

"""
HessianExtras
Abstract type for additional information needed by [`hessian`](@ref) and its variants.
"""
abstract type HessianExtras <: Extras end
struct NoHessianExtras <: HessianExtras end

"""
SecondDerivativeExtras
Abstract type for additional information needed by [`second_derivative`](@ref) and its variants.
"""
abstract type SecondDerivativeExtras <: Extras end
struct NoSecondDerivativeExtras <: SecondDerivativeExtras end

0 comments on commit 6c829fc

Please sign in to comment.