diff --git a/Project.toml b/Project.toml index 4de0239..ff33a90 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LogDensityProblems" uuid = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" authors = ["Tamas K. Papp "] -version = "1.0.2" +version = "1.0.3" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/AD_ForwardDiff.jl b/src/AD_ForwardDiff.jl index 3b20070..aab4522 100644 --- a/src/AD_ForwardDiff.jl +++ b/src/AD_ForwardDiff.jl @@ -45,7 +45,7 @@ end function logdensity_and_gradient(fℓ::ForwardDiffLogDensity, x::AbstractVector) @unpack ℓ, gradientconfig = fℓ - buffer = _diffresults_buffer(ℓ, x) + buffer = _diffresults_buffer(x) result = ForwardDiff.gradient!(buffer, Base.Fix1(logdensity, ℓ), x, gradientconfig) _diffresults_extract(result) end diff --git a/src/AD_ReverseDiff.jl b/src/AD_ReverseDiff.jl index 07e0b10..91f07ef 100644 --- a/src/AD_ReverseDiff.jl +++ b/src/AD_ReverseDiff.jl @@ -50,7 +50,7 @@ end function logdensity_and_gradient(∇ℓ::ReverseDiffLogDensity, x::AbstractVector) @unpack ℓ, compiledtape = ∇ℓ - buffer = _diffresults_buffer(ℓ, x) + buffer = _diffresults_buffer(x) if compiledtape === nothing result = ReverseDiff.gradient!(buffer, Base.Fix1(logdensity, ℓ), x) else diff --git a/src/DiffResults_helpers.jl b/src/DiffResults_helpers.jl index 9ee4e1a..3624ef4 100644 --- a/src/DiffResults_helpers.jl +++ b/src/DiffResults_helpers.jl @@ -11,10 +11,10 @@ $(SIGNATURES) Allocate a DiffResults buffer for a gradient, taking the element type of `x` into account (heuristically). """ -function _diffresults_buffer(ℓ, x) +function _diffresults_buffer(x) T = eltype(x) S = T <: Real ? float(Real) : Float64 # heuristic - DiffResults.MutableDiffResult(zero(S), (Vector{S}(undef, dimension(ℓ)), )) + DiffResults.MutableDiffResult(zero(S), (similar(x, S), )) end """ @@ -25,5 +25,5 @@ constructed with [`diffresults_buffer`](@ref). Gradient is not copied as caller vector. """ function _diffresults_extract(diffresult::DiffResults.DiffResult) - DiffResults.value(diffresult)::Real, DiffResults.gradient(diffresult) + DiffResults.value(diffresult), DiffResults.gradient(diffresult) end