From e674fc49d26ed4921dfa7a8df50c7786174a3231 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 11 Jul 2024 10:41:40 +0100 Subject: [PATCH 1/3] drafts --- Project.toml | 2 +- ext/LogDensityProblemsADReverseDiffExt.jl | 6 ++++++ src/LogDensityProblemsAD.jl | 12 +++++++++++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 003e0a0..3b7694d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LogDensityProblemsAD" uuid = "996a588d-648d-4e1f-a8f0-a84b347e47b1" authors = ["Tamás K. Papp "] -version = "1.9.0" +version = "2.0.0" [deps] DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" diff --git a/ext/LogDensityProblemsADReverseDiffExt.jl b/ext/LogDensityProblemsADReverseDiffExt.jl index d101a7b..07de1d1 100644 --- a/ext/LogDensityProblemsADReverseDiffExt.jl +++ b/ext/LogDensityProblemsADReverseDiffExt.jl @@ -50,6 +50,12 @@ function ADgradient(::Val{:ReverseDiff}, ℓ; ReverseDiffLogDensity(ℓ, _compiledtape(ℓ, compile, x)) end +function ADgradient(::Val{:ReverseDiff}, ∇ℓ::ADGradientWrapper; + compile::Union{Val{true},Val{false}}=Val(false), + x::Union{Nothing,AbstractVector}=nothing) + ADgradient(Val{:ReverseDiff}, ∇ℓ.ℓ; compile=compile, x=x) +end + _compiledtape(ℓ, compile, x) = nothing _compiledtape(ℓ, ::Val{true}, ::Nothing) = _compiledtape(ℓ, Val(true), zeros(dimension(ℓ))) function _compiledtape(ℓ, ::Val{true}, x) diff --git a/src/LogDensityProblemsAD.jl b/src/LogDensityProblemsAD.jl index c2bb78f..1cc47c6 100644 --- a/src/LogDensityProblemsAD.jl +++ b/src/LogDensityProblemsAD.jl @@ -3,7 +3,7 @@ Automatic differentiation backends for LogDensityProblems. """ module LogDensityProblemsAD -export ADgradient +export ADgradient, replace_ℓ using DocStringExtensions: SIGNATURES import LogDensityProblems: logdensity, logdensity_and_gradient, capabilities, dimension @@ -38,6 +38,16 @@ Base.copy(x::ADGradientWrapper) = x # no-op, except for ForwardDiff """ $(SIGNATURES) +Replace the log density in an `ADGradientWrapper`. +""" +function replace_ℓ(∇ℓ::ADGradientWrapper, new_ℓ) + @info "$(typeof(∇ℓ)) not supported for replacement" + throw(MethodError(replace_ℓ, (∇ℓ, new_ℓ))) +end + +""" +$(SIGNATURES) + Wrap `P` using automatic differentiation to obtain a gradient. `kind` is usually a `Val` type with a symbol that refers to a package, for example From 269ec8d51ff403c17a24a423335d3b4bccc6409d Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 11 Jul 2024 10:57:54 +0100 Subject: [PATCH 2/3] =?UTF-8?q?add=20`replace=5F=E2=84=93`=20for=20Reverse?= =?UTF-8?q?DiffExt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ext/LogDensityProblemsADReverseDiffExt.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ext/LogDensityProblemsADReverseDiffExt.jl b/ext/LogDensityProblemsADReverseDiffExt.jl index 07de1d1..500b5c5 100644 --- a/ext/LogDensityProblemsADReverseDiffExt.jl +++ b/ext/LogDensityProblemsADReverseDiffExt.jl @@ -49,13 +49,19 @@ function ADgradient(::Val{:ReverseDiff}, ℓ; x::Union{Nothing,AbstractVector}=nothing) ReverseDiffLogDensity(ℓ, _compiledtape(ℓ, compile, x)) end - function ADgradient(::Val{:ReverseDiff}, ∇ℓ::ADGradientWrapper; compile::Union{Val{true},Val{false}}=Val(false), x::Union{Nothing,AbstractVector}=nothing) ADgradient(Val{:ReverseDiff}, ∇ℓ.ℓ; compile=compile, x=x) end +function LogDensityProblemsAD.replace_ℓ(∇ℓ::ReverseDiffLogDensity{L,C}, new_ℓ) + ReverseDiffLogDensity(new_ℓ, _compiledtape(new_ℓ, Val(true), nothing)) +end +function LogDensityProblemsAD.replace_ℓ(∇ℓ::ReverseDiffLogDensity{L,Nothing}, new_ℓ) + ReverseDiffLogDensity(new_ℓ, _compiledtape(new_ℓ, Val(false), nothing)) +end + _compiledtape(ℓ, compile, x) = nothing _compiledtape(ℓ, ::Val{true}, ::Nothing) = _compiledtape(ℓ, Val(true), zeros(dimension(ℓ))) function _compiledtape(ℓ, ::Val{true}, x) From 08976294ff5666f735e316fb164af6233f795a72 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 12 Jul 2024 13:56:14 +0100 Subject: [PATCH 3/3] apply suggestions from @tpapp --- Project.toml | 4 +++- src/LogDensityProblemsAD.jl | 8 +++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 3b7694d..ed58ae6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "LogDensityProblemsAD" uuid = "996a588d-648d-4e1f-a8f0-a84b347e47b1" authors = ["Tamás K. Papp "] -version = "2.0.0" +version = "1.9.0" [deps] +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" @@ -31,6 +32,7 @@ LogDensityProblemsADZygoteExt = "Zygote" [compat] ADTypes = "0.1.7, 0.2, 1" +Compat = "4.15.0" DocStringExtensions = "0.8, 0.9" Enzyme = "0.11, 0.12" FiniteDifferences = "0.12" diff --git a/src/LogDensityProblemsAD.jl b/src/LogDensityProblemsAD.jl index 1cc47c6..94a6a7f 100644 --- a/src/LogDensityProblemsAD.jl +++ b/src/LogDensityProblemsAD.jl @@ -3,7 +3,8 @@ Automatic differentiation backends for LogDensityProblems. """ module LogDensityProblemsAD -export ADgradient, replace_ℓ +export ADgradient +@compat public replace_ℓ using DocStringExtensions: SIGNATURES import LogDensityProblems: logdensity, logdensity_and_gradient, capabilities, dimension @@ -40,10 +41,7 @@ $(SIGNATURES) Replace the log density in an `ADGradientWrapper`. """ -function replace_ℓ(∇ℓ::ADGradientWrapper, new_ℓ) - @info "$(typeof(∇ℓ)) not supported for replacement" - throw(MethodError(replace_ℓ, (∇ℓ, new_ℓ))) -end +function replace_ℓ(∇ℓ::ADGradientWrapper, new_ℓ) end """ $(SIGNATURES)