Skip to content

Commit

Permalink
Fix batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Dec 4, 2024
1 parent 01bfe8a commit c4f2d52
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
function DI.BatchSizeSettings(::AutoEnzyme, N::Integer)
B = DI.reasonable_batchsize(N, 16)
singlebatch = B == N
aligned = N % B == 0
return DI.BatchSizeSettings{B,singlebatch,aligned}(N)
return DI.BatchSizeSettings{B}(N)
end

to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
function DI.BatchSizeSettings(::AutoForwardDiff{nothing}, N::Integer)
B = ForwardDiff.pickchunksize(N)
singlebatch = B == N
aligned = N % B == 0
return DI.BatchSizeSettings{B,singlebatch,aligned}(N)
chunksize = ForwardDiff.pickchunksize(N)
return DI.BatchSizeSettings{chunksize}(N)
end

function DI.BatchSizeSettings(::AutoForwardDiff{chunksize}, N::Integer) where {chunksize}
if chunksize > N
throw(ArgumentError("Fixed chunksize $chunksize larger than input size $N"))
end
B = chunksize
singlebatch = B == N
aligned = N % B == 0
return DI.BatchSizeSettings{B,singlebatch,aligned}(N)
return DI.BatchSizeSettings{chunksize}(N)

Check warning on line 7 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl#L7

Added line #L7 was not covered by tests
end

function DI.threshold_batchsize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ end

DI.ismutable_array(::Type{<:SArray}) = false

function DI.BatchSizeSettings(::DI.AutoSimpleFiniteDiff{nothing}, x::StaticArray)
return DI.BatchSizeSettings{length(x),true,true}(length(x))
end

function DI.BatchSizeSettings(::AutoForwardDiff{nothing}, x::StaticArray)
return DI.BatchSizeSettings{length(x),true,true}(length(x))
end
Expand All @@ -22,4 +26,16 @@ function DI.BatchSizeSettings(::AutoEnzyme, x::StaticArray)
return DI.BatchSizeSettings{length(x),true,true}(length(x))
end

function DI.BatchSizeSettings(
::DI.AutoSimpleFiniteDiff{chunksize}, x::StaticArray
) where {chunksize}
return DI.BatchSizeSettings{chunksize}(Val(length(x)))
end

function DI.BatchSizeSettings(

Check warning on line 35 in DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl#L35

Added line #L35 was not covered by tests
::AutoForwardDiff{chunksize}, x::StaticArray
) where {chunksize}
return DI.BatchSizeSettings{chunksize}(Val(length(x)))

Check warning on line 38 in DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl#L38

Added line #L38 was not covered by tests
end

end
10 changes: 2 additions & 8 deletions DifferentiationInterface/src/misc/simple_finite_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,11 @@ inplace_support(::AutoSimpleFiniteDiff) = InPlaceSupported()

function BatchSizeSettings(::AutoSimpleFiniteDiff{nothing}, N::Integer)
B = reasonable_batchsize(N, 12)
singlebatch = B == N
aligned = N % B == 0
return BatchSizeSettings{B,singlebatch,aligned}(N)
return BatchSizeSettings{B}(N)
end

function BatchSizeSettings(::AutoSimpleFiniteDiff{chunksize}, N::Integer) where {chunksize}
@assert chunksize <= N
B = chunksize
singlebatch = B == N
aligned = N % B == 0
return BatchSizeSettings{B,singlebatch,aligned}(N)
return BatchSizeSettings{chunksize}(N)
end

function threshold_batchsize(
Expand Down
16 changes: 15 additions & 1 deletion DifferentiationInterface/src/utils/batchsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Configuration for the batch size deduced from a backend and a sample array of le
# Type parameters
- `B::Int`: batch size
- `singlebatch::Bool`: whether `B > N`
- `singlebatch::Bool`: whether `B == N` (`B > N` is not allowed)
- `aligned::Bool`: whether `N % B == 0`
# Fields
Expand All @@ -22,11 +22,25 @@ struct BatchSizeSettings{B,singlebatch,aligned}
end

function BatchSizeSettings{B,singlebatch,aligned}(N::Integer) where {B,singlebatch,aligned}
B > N && throw(ArgumentError("Batch size $B larger than input size $N"))
A = div(N, B, RoundUp)
B_last = N % B
return BatchSizeSettings{B,singlebatch,aligned}(N, A, B_last)
end

function BatchSizeSettings{B}(::Val{N}) where {B,N}
singlebatch = B == N
aligned = N % B == 0
return BatchSizeSettings{B,singlebatch,aligned}(N)
end

function BatchSizeSettings{B}(N::Integer) where {B}
# type-unstable
singlebatch = B == N
aligned = N % B == 0
return BatchSizeSettings{B,singlebatch,aligned}(N)
end

function BatchSizeSettings(::AbstractADType, N::Integer)
B = 1
singlebatch = false
Expand Down
4 changes: 2 additions & 2 deletions DifferentiationInterface/test/Core/Internals/batchsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ BSS = BatchSizeSettings
)
end

@testset "ForwardDiff (adaptive)" begin
@testset "SimpleFiniteDiff (adaptive)" begin
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(2))) isa BSS{2,true,true}
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(6))) isa BSS{6,true,true}
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(12))) isa BSS{12,true,true}
Expand All @@ -38,7 +38,7 @@ end
BSS{100,true,true}
end

@testset "ForwardDiff (fixed)" begin
@testset "SimpleFiniteDiff (fixed)" begin
@test_throws ArgumentError pick_batchsize(AutoSimpleFiniteDiff(; chunksize=4), zeros(2))
@test_throws ArgumentError pick_batchsize(
AutoSimpleFiniteDiff(; chunksize=4), @SVector(zeros(2))
Expand Down

0 comments on commit c4f2d52

Please sign in to comment.