From 8cc63ccdfa61c3ef180483e96bf1e57459847a29 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 31 Dec 2024 10:20:00 +0100 Subject: [PATCH 1/2] fix: check nothing output for Zygote --- .../DifferentiationInterfaceZygoteExt.jl | 41 ++++++++++++-- .../test/Back/Zygote/test.jl | 53 +++++++++++-------- 2 files changed, 68 insertions(+), 26 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 0f681c9f9..d86f2ec89 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -6,6 +6,24 @@ using ForwardDiff: ForwardDiff using Zygote: ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian +struct ZygoteNothingError <: Exception + f + x + contexts +end + +function Base.showerror(io::IO, e::ZygoteNothingError) + (; f, x, contexts) = e + sig = (typeof(x), map(typeof ∘ DI.unwrap, contexts)...) + return print( + io, + "Zygote failed to differentiate function `$f` with argument types `$sig` (the pullback returned `nothing`).", + ) +end + +check_nothing(::Nothing, f, x, contexts) = throw(ZygoteNothingError(f, x, contexts)) +check_nothing(::Any, f, x, contexts) = nothing + DI.check_available(::AutoZygote) = true DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported() @@ -46,6 +64,7 @@ function DI.value_and_pullback( tx = map(ty) do dy first(pb(dy)) end + check_nothing(first(tx), f, x, contexts) return y, tx end @@ -61,6 +80,7 @@ function DI.value_and_pullback( tx = map(ty) do dy first(pb(dy)) end + check_nothing(first(tx), f, x, contexts) return copy(y), tx end @@ -76,6 +96,7 @@ function DI.pullback( tx = map(ty) do dy first(pb(dy)) end + check_nothing(first(tx), f, x, contexts) return tx end @@ -95,6 +116,7 @@ function DI.value_and_gradient( contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) + check_nothing(first(grad), f, x, contexts) return val, first(grad) end @@ -105,7 +127,9 @@ function DI.gradient( x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} - return first(gradient(f, x, map(DI.unwrap, contexts)...)) + grad = gradient(f, x, map(DI.unwrap, contexts)...) + check_nothing(first(grad), f, x, contexts) + return first(grad) end function DI.value_and_gradient!( @@ -146,8 +170,11 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} - return f(x, map(DI.unwrap, contexts)...), - first(jacobian(f, x, map(DI.unwrap, contexts)...)) # https://github.com/FluxML/Zygote.jl/issues/1506 + y = f(x, map(DI.unwrap, contexts)...) + # https://github.com/FluxML/Zygote.jl/issues/1506 + jac = jacobian(f, x, map(DI.unwrap, contexts)...) + check_nothing(first(jac), f, x, contexts) + return y, first(jac) end function DI.jacobian( @@ -157,7 +184,9 @@ function DI.jacobian( x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} - return first(jacobian(f, x, map(DI.unwrap, contexts)...)) + jac = jacobian(f, x, map(DI.unwrap, contexts)...) + check_nothing(first(jac), f, x, contexts) + return first(jac) end function DI.value_and_jacobian!( @@ -266,7 +295,9 @@ function DI.hessian( contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} fc = DI.with_contexts(f, contexts...) - return hessian(fc, x) + hess = hessian(fc, x) + check_nothing(hess, f, x, contexts) + return hess end function DI.hessian!( diff --git a/DifferentiationInterface/test/Back/Zygote/test.jl b/DifferentiationInterface/test/Back/Zygote/test.jl index 9204673d9..25c24a3eb 100644 --- a/DifferentiationInterface/test/Back/Zygote/test.jl +++ b/DifferentiationInterface/test/Back/Zygote/test.jl @@ -24,27 +24,38 @@ end ## Dense -test_differentiation( - backends, - default_scenarios(; include_constantified=true); - excluded=[:second_derivative], - logging=LOGGING, -); - -test_differentiation(second_order_backends; logging=LOGGING); - -test_differentiation( - backends[1], - vcat(component_scenarios(), gpu_scenarios()); - excluded=SECOND_ORDER, - logging=LOGGING, -) +@testset "Dense" begin + test_differentiation( + backends, + default_scenarios(; include_constantified=true); + excluded=[:second_derivative], + logging=LOGGING, + ) + + test_differentiation(second_order_backends; logging=LOGGING) + + test_differentiation( + backends[1], + vcat(component_scenarios(), gpu_scenarios()); + excluded=SECOND_ORDER, + logging=LOGGING, + ) +end ## Sparse -test_differentiation( - MyAutoSparse.(vcat(backends, second_order_backends)), - sparse_scenarios(; band_sizes=0:-1); - sparsity=true, - logging=LOGGING, -) +@testset "Sparse" begin + test_differentiation( + MyAutoSparse.(vcat(backends, second_order_backends)), + sparse_scenarios(; band_sizes=0:-1); + sparsity=true, + logging=LOGGING, + ) +end + +## Errors + +@testset "Errors" begin + safe_log(x) = x > zero(x) ? log(x) : convert(typeof(x), NaN) + @test_throws Exception derivative(safe_log, AutoZygote(), 0.0) +end From a9195738deb0a9326796a3e7844b76095296458a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 4 Jan 2025 19:57:16 +0100 Subject: [PATCH 2/2] Test error msg --- DifferentiationInterface/test/Back/Zygote/test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/test/Back/Zygote/test.jl b/DifferentiationInterface/test/Back/Zygote/test.jl index 25c24a3eb..e83e3c257 100644 --- a/DifferentiationInterface/test/Back/Zygote/test.jl +++ b/DifferentiationInterface/test/Back/Zygote/test.jl @@ -57,5 +57,5 @@ end @testset "Errors" begin safe_log(x) = x > zero(x) ? log(x) : convert(typeof(x), NaN) - @test_throws Exception derivative(safe_log, AutoZygote(), 0.0) + @test_throws "Zygote failed to differentiate" derivative(safe_log, AutoZygote(), 0.0) end