From 40d6d4d4b2655c223b18c4530734dba354667c52 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 4 Dec 2024 13:16:08 +0100 Subject: [PATCH] Fix batch size --- DifferentiationInterface/Project.toml | 2 +- .../test/Back/ForwardDiff/test.jl | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index e0cb9fec2..b4d478ab2 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.24" +version = "0.6.25" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 00ed7c5eb..5c4758fdd 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -4,9 +4,10 @@ Pkg.add("ForwardDiff") using ADTypes: ADTypes using ComponentArrays: ComponentArrays using DifferentiationInterface, DifferentiationInterfaceTest +import DifferentiationInterface as DI import DifferentiationInterfaceTest as DIT using ForwardDiff: ForwardDiff -using StaticArrays: StaticArrays +using StaticArrays: StaticArrays, @SVector using Test using ExplicitImports @@ -75,7 +76,18 @@ test_differentiation( test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING) -@testset verbose = true "No allocations on StaticArrays" begin +@testset verbose = true "StaticArrays" begin + @testset "Batch size" begin + @test DI.pick_batchsize(AutoForwardDiff(), rand(7)) isa DI.BatchSizeSettings{7} + @test DI.pick_batchsize(AutoForwardDiff(; chunksize=5), rand(7)) isa + DI.BatchSizeSettings{5} + @test (@inferred DI.pick_batchsize(AutoForwardDiff(), @SVector(rand(7)))) isa + DI.BatchSizeSettings{7} + @test (@inferred DI.pick_batchsize( + AutoForwardDiff(; chunksize=5), @SVector(rand(7)) + )) isa DI.BatchSizeSettings{5} + end + filtered_static_scenarios = filter(static_scenarios(; include_batchified=false)) do scen DIT.function_place(scen) == :out && DIT.operator_place(scen) == :out end