Skip to content

Commit

Permalink
fix accumulate tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 19, 2022
1 parent c7e7f13 commit 2b5877c
Showing 1 changed file with 10 additions and 35 deletions.
45 changes: 10 additions & 35 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,10 +359,9 @@ end
end
end # cumprod

@testset "accumulate(f, ::Array)" begin
@testset "accumulate(f, ::Vector)" begin
# `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Nothing)`.
# The rule is now attached there, as this is the simplest way to handle `init` keyword.
@eval using Base: _accumulate!

# Simple
y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1))
Expand All @@ -372,9 +371,9 @@ end
@test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}}
@test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented

y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4])
@test y2 accumulate(/, [1 2; 3 4])
@test b2(ones(2, 2))[3] [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
# y2, b2 = rrule(CFG, _accumulate!, /, [0 0; 0 0], [1 2; 3 4], :, nothing)
# @test y2 ≈ accumulate(/, [1 2; 3 4.0])
# @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6

# Test execution order
c3 = Counter()
Expand Down Expand Up @@ -404,35 +403,11 @@ end
# ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string

# Finite differencing
test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand()))
test_rrule(accumulate, /, 1 .+ rand(3, 4))
test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
# test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand()))
test_rrule(_accumulate!, *, randn(5) NoTangent(), randn(5), nothing, Some(rand()))
# test_rrule(accumulate, /, 1 .+ rand(3, 4))
test_rrule(_accumulate!, /, randn(4) NoTangent(), 1 .+ rand(4), nothing, nothing)
# test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
test_rrule(_accumulate!, ^, randn(6) NoTangent(), 1 .+ rand(6), nothing, Some(rand()))
end
@testset "accumulate(f, ::Tuple)" begin
# Simple
y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
@test y1 == (1, 2, 6, 24)
@test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))

# Finite differencing
test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)

test_rrule(_accumulate!, *, randn(5) NoTangent(), randn(5), nothing, nothing)
test_rrule(_accumulate!, /, randn(5) NoTangent(), randn(5), nothing, Some(1 + rand()))
# if VERSION >= v"1.5"
# test_rrule(accumulate, /, 1 .+ rand(3, 4))
# test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
# end
end
# VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin
# # Simple
# y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
# @test y1 == (1, 2, 6, 24)
# @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))

# # Finite differencing
# test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
# test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
# end
end

0 comments on commit 2b5877c

Please sign in to comment.