Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: get rid of implicit imports and clarify extension imports #649

Merged
merged 25 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.23"
version = "0.6.24"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
Expand All @@ -32,10 +33,10 @@ DifferentiationInterfaceEnzymeExt = "Enzyme"
DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
DifferentiationInterfaceStaticArraysExt = "StaticArrays"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ using ChainRulesCore:
frule_via_ad,
rrule_via_ad
import DifferentiationInterface as DI
using DifferentiationInterface:
Constant, DifferentiateWith, NoPullbackPrep, NoPushforwardPrep, PullbackPrep, unwrap

ruleconfig(backend::AutoChainRules) = backend.ruleconfig

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function ChainRulesCore.rrule(dw::DifferentiateWith, x)
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
(; f, backend) = dw
y = f(x)
prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,39 +1,39 @@
## Pullback

struct ChainRulesPullbackPrepSamePoint{Y,PB} <: PullbackPrep
struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
y::Y
pb::PB
end

function DI.prepare_pullback(
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{Constant,C}
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.Constant,C}
) where {C}
return NoPullbackPrep()
return DI.NoPullbackPrep()
end

function DI.prepare_pullback_same_point(
f,
::NoPullbackPrep,
::DI.NoPullbackPrep,
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{Constant,C},
contexts::Vararg{DI.Constant,C},
) where {C}
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x, map(unwrap, contexts)...)
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
return ChainRulesPullbackPrepSamePoint(y, pb)
end

function DI.value_and_pullback(
f,
::NoPullbackPrep,
::DI.NoPullbackPrep,
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{Constant,C},
contexts::Vararg{DI.Constant,C},
) where {C}
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x, map(unwrap, contexts)...)
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
tx = map(ty) do dy
pb(dy)[2]
end
Expand All @@ -46,7 +46,7 @@ function DI.value_and_pullback(
::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{Constant,C},
contexts::Vararg{DI.Constant,C},
) where {C}
(; y, pb) = prep
tx = map(ty) do dy
Expand All @@ -61,7 +61,7 @@ function DI.pullback(
::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{Constant,C},
contexts::Vararg{DI.Constant,C},
) where {C}
(; pb) = prep
tx = map(ty) do dy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module DifferentiationInterfaceDiffractorExt

using ADTypes: ADTypes, AutoDiffractor
import DifferentiationInterface as DI
using DifferentiationInterface: NoPushforwardPrep
using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆

DI.check_available(::AutoDiffractor) = true
Expand All @@ -11,9 +10,9 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()

## Pushforward

DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = NoPushforwardPrep()
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = DI.NoPushforwardPrep()

function DI.pushforward(f, ::NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
function DI.pushforward(f, ::DI.NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
ty = map(tx) do dx
# code copied from Diffractor.jl
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
Expand All @@ -24,7 +23,7 @@ function DI.pushforward(f, ::NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
end

function DI.value_and_pushforward(
f, prep::NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
)
return f(x), DI.pushforward(f, prep, backend, x, tx)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,6 @@ module DifferentiationInterfaceEnzymeExt
using ADTypes: ADTypes, AutoEnzyme
using Base: Fix1
import DifferentiationInterface as DI
using DifferentiationInterface:
Context,
DerivativePrep,
GradientPrep,
JacobianPrep,
HVPPrep,
PullbackPrep,
PushforwardPrep,
NoDerivativePrep,
NoGradientPrep,
NoHVPPrep,
NoJacobianPrep,
NoPullbackPrep,
NoPushforwardPrep
using Enzyme:
Active,
Annotation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ function DI.prepare_pushforward(
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
return NoPushforwardPrep()
return DI.NoPushforwardPrep()
end

function DI.value_and_pushforward(
f::F,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
dx_sametype = convert(typeof(x), only(tx))
Expand All @@ -29,11 +29,11 @@ end

function DI.value_and_pushforward(
f::F,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f_and_df = get_f_and_df(f, backend, Val(B))
tx_sametype = map(Fix1(convert, typeof(x)), tx)
Expand All @@ -46,11 +46,11 @@ end

function DI.pushforward(
f::F,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
dx_sametype = convert(typeof(x), only(tx))
Expand All @@ -63,11 +63,11 @@ end

function DI.pushforward(
f::F,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f_and_df = get_f_and_df(f, backend, Val(B))
tx_sametype = map(Fix1(convert, typeof(x)), tx)
Expand All @@ -81,11 +81,11 @@ end
function DI.value_and_pushforward!(
f::F,
ty::NTuple,
prep::NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
# dy cannot be passed anyway
y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
Expand All @@ -96,11 +96,11 @@ end
function DI.pushforward!(
f::F,
ty::NTuple,
prep::NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
# dy cannot be passed anyway
new_ty = DI.pushforward(f, prep, backend, x, tx, contexts...)
Expand All @@ -110,7 +110,7 @@ end

## Gradient

struct EnzymeForwardGradientPrep{B,O} <: GradientPrep
struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep
shadows::O
end

Expand Down Expand Up @@ -175,7 +175,7 @@ end

## Jacobian

struct EnzymeForwardOneArgJacobianPrep{B,O} <: JacobianPrep
struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep
shadows::O
output_length::Int
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ function DI.prepare_pushforward(
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
return NoPushforwardPrep()
return DI.NoPushforwardPrep()
end

function DI.value_and_pushforward(
f!::F,
y,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
f!_and_df! = get_f_and_df(f!, backend)
dx_sametype = convert(typeof(x), only(tx))
Expand All @@ -39,11 +39,11 @@ end
function DI.value_and_pushforward(
f!::F,
y,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f!_and_df! = get_f_and_df(f!, backend, Val(B))
tx_sametype = map(Fix1(convert, typeof(x)), tx)
Expand All @@ -64,11 +64,11 @@ end
function DI.pushforward(
f!::F,
y,
prep::NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
_, ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)
return ty
Expand All @@ -78,11 +78,11 @@ function DI.value_and_pushforward!(
f!::F,
y,
ty::NTuple,
prep::NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
y, new_ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)
foreach(copyto!, ty, new_ty)
Expand All @@ -93,11 +93,11 @@ function DI.pushforward!(
f!::F,
y,
ty::NTuple,
prep::NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
new_ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...)
foreach(copyto!, ty, new_ty)
Expand Down
Loading