Skip to content

Commit

Permalink
perf: check mutability of array before preallocating dual buffer (#619)
Browse files Browse the repository at this point in the history
* perf: check mutability of array before preallocating ForwardDiff dual buffer

* Add allocation testing
  • Loading branch information
gdalle authored Nov 10, 2024
1 parent 9a524d3 commit d537c45
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 5 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.22"
version = "0.6.23"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ using DifferentiationInterface:
outer,
shuffled_gradient,
unwrap,
with_contexts
with_contexts,
ismutable_array
import ForwardDiff.DiffResults as DR
using ForwardDiff.DiffResults:
DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ function DI.prepare_pushforward(
f::F, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{Context,C}
) where {F,C}
T = tag_type(f, backend, x)
xdual_tmp = make_dual_similar(T, x, tx)
if ismutable_array(x)
xdual_tmp = make_dual_similar(T, x, tx)
else
xdual_tmp = nothing
end
return ForwardDiffOneArgPushforwardPrep{T,typeof(xdual_tmp)}(xdual_tmp)
end

Expand All @@ -92,8 +96,12 @@ function compute_ydual_onearg(
tx::NTuple{B},
contexts::Vararg{Context,C},
) where {F,T,B,C}
(; xdual_tmp) = prep
make_dual!(T, xdual_tmp, x, tx)
if ismutable_array(x)
make_dual!(T, prep.xdual_tmp, x, tx)
xdual_tmp = prep.xdual_tmp
else
xdual_tmp = make_dual(T, x, tx)
end
contexts_dual = translate(T, Val(B), contexts...)
ydual = f(xdual_tmp, contexts_dual...)
return ydual
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ function DI.stack_vec_row(t::NTuple{B,<:StaticArray}) where {B}
return vcat(transpose.(map(vec, t))...)
end

DI.ismutable_array(::Type{<:SArray}) = false

function DI.BatchSizeSettings(::AutoForwardDiff{nothing}, x::StaticArray)
return BatchSizeSettings{length(x),true,true}(length(x))
end
Expand Down
3 changes: 3 additions & 0 deletions DifferentiationInterface/src/utils/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
stack_vec_col(t::NTuple) = stack(vec, t; dims=2)
stack_vec_row(t::NTuple) = stack(vec, t; dims=1)

ismutable_array(::Type) = true
ismutable_array(x) = ismutable_array(typeof(x))
17 changes: 17 additions & 0 deletions DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Pkg.add("ForwardDiff")

using ComponentArrays: ComponentArrays
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using ForwardDiff: ForwardDiff
using StaticArrays: StaticArrays
using Test
Expand Down Expand Up @@ -65,3 +66,19 @@ test_differentiation(
## Static

test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING)

@testset verbose = true "No allocations on StaticArrays" begin
filtered_static_scenarios = filter(static_scenarios(; include_batchified=false)) do scen
DIT.function_place(scen) == :out && DIT.operator_place(scen) == :out
end
data = benchmark_differentiation(
AutoForwardDiff(),
filtered_static_scenarios;
benchmark=:prepared,
excluded=[:hessian, :pullback], # TODO: figure this out
logging=LOGGING,
)
@testset "$(row[:scenario])" for row in eachrow(data)
@test row[:allocs] == 0
end
end;

2 comments on commit d537c45

@gdalle
Copy link
Member Author

@gdalle gdalle commented on d537c45 Nov 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: "File (Julia)Project.toml not found"

Please sign in to comment.