From b362324c020f1462343e46f1c41af97760950414 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Dec 2024 07:15:58 -0500 Subject: [PATCH] fix: handle debug leafs with dispatch (#1115) * fix: handle debug leafs with dispatch * test: add a test for pooling layers * fix: pass in exclude to layer_map --- src/Lux.jl | 2 +- src/contrib/contrib.jl | 3 ++- src/contrib/debug.jl | 10 ++++++---- src/contrib/map.jl | 8 +++++++- src/layers/pooling.jl | 28 ++++++++++++++-------------- test/Project.toml | 2 +- test/contrib/debug_tests.jl | 17 +++++++++++++++++ 7 files changed, 48 insertions(+), 22 deletions(-) diff --git a/src/Lux.jl b/src/Lux.jl index 1b99492bb5..29014cfa6f 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -10,7 +10,7 @@ using Compat: @compat using ConcreteStructs: @concrete using EnzymeCore: EnzymeRules using FastClosures: @closure -using Functors: Functors, fmap +using Functors: Functors, KeyPath, fmap using GPUArraysCore: @allowscalar using Markdown: @doc_str using NNlib: NNlib diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index a4b07170b8..1b808b319a 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -15,7 +15,8 @@ using Static: StaticSymbol, StaticBool, True, known, static, dynamic using ..Lux: Lux, Optional using ..Utils: Utils, BoolType, SymbolType -using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer, apply +using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer, + AbstractLuxContainerLayer, apply const CRC = ChainRulesCore diff --git a/src/contrib/debug.jl b/src/contrib/debug.jl index 38c13f36fe..e4c475d65e 100644 --- a/src/contrib/debug.jl +++ b/src/contrib/debug.jl @@ -141,8 +141,10 @@ Recurses into the `layer` and replaces the inner most non Container Layers with See [`Lux.Experimental.DebugLayer`](@ref) for details about the Keyword Arguments. """ macro debug_mode(layer, kwargs...) - kws = esc.(kwargs) - return :($(fmap_with_path)( - (kp, l) -> DebugLayer(l; location=$(KeyPath)($(Meta.quot(layer)), kp), $(kws...)), - $(esc(layer)))) + return esc(:( + $(fmap_with_path)( + (kp, l) -> $(DebugLayer)( + l; location=$(KeyPath)($(Meta.quot(layer)), kp), $(kwargs...)), + $(layer); exclude=$(layer_map_leaf)) + )) end diff --git a/src/contrib/map.jl b/src/contrib/map.jl index f5142f0db9..efd6acf29a 100644 --- a/src/contrib/map.jl +++ b/src/contrib/map.jl @@ -57,7 +57,8 @@ true ``` """ function layer_map(f, l, ps, st) - return fmap_with_path(l, ps, st; walk=LayerWalkWithPath()) do kp, layer, ps_, st_ + return fmap_with_path( + l, ps, st; walk=LayerWalkWithPath(), exclude=layer_map_leaf) do kp, layer, ps_, st_ return f(layer, ps_, st_, kp) end end @@ -103,3 +104,8 @@ function perform_layer_map(recurse, kp, ps_children, st_children, layer_children return layer_children_new, ps_children_new, st_children_new end + +layer_map_leaf(::KeyPath, ::AbstractLuxLayer) = true +layer_map_leaf(::KeyPath, ::AbstractLuxWrapperLayer) = false +layer_map_leaf(::KeyPath, ::AbstractLuxContainerLayer) = false +layer_map_leaf(::KeyPath, x) = Functors.isleaf(x) diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl index 819aaaeebf..943eb947c1 100644 --- a/src/layers/pooling.jl +++ b/src/layers/pooling.jl @@ -197,6 +197,8 @@ for layer_op in (:Max, :Mean, :LP) layer <: PoolingLayer end + Experimental.layer_map_leaf(::KeyPath, ::$(layer_name)) = true + function $(layer_name)( window::Tuple{Vararg{IntegerType}}; stride=window, pad=0, dilation=1, p=2) return $(layer_name)(PoolingLayer(static(:generic), static($(Meta.quot(op))), @@ -204,19 +206,15 @@ for layer_op in (:Max, :Mean, :LP) end function Base.show(io::IO, m::$(layer_name)) - kernel_size = m.layer.mode.kernel_size + (; mode, op) = m.layer + (; kernel_size, pad, stride, dilation) = mode print(io, string($(Meta.quot(layer_name))), "($(kernel_size)") - pad = m.layer.mode.pad all(==(0), pad) || print(io, ", pad=", PrettyPrinting.tuple_string(pad)) - stride = m.layer.mode.stride stride == kernel_size || print(io, ", stride=", PrettyPrinting.tuple_string(stride)) - dilation = m.layer.mode.dilation all(==(1), dilation) || print(io, ", dilation=", PrettyPrinting.tuple_string(dilation)) - if $(Meta.quot(op)) == :lp - m.layer.op.p == 2 || print(io, ", p=", m.layer.op.p) - end + $(Meta.quot(op)) == :lp && (op.p == 2 || print(io, ", p=", op.p)) print(io, ")") end @@ -228,15 +226,16 @@ for layer_op in (:Max, :Mean, :LP) layer <: PoolingLayer end + Experimental.layer_map_leaf(::KeyPath, ::$(global_layer_name)) = true + function $(global_layer_name)(; p=2) return $(global_layer_name)(PoolingLayer(static(:global), $(Meta.quot(op)); p)) end function Base.show(io::IO, g::$(global_layer_name)) + (; op) = g.layer print(io, string($(Meta.quot(global_layer_name))), "(") - if $(Meta.quot(op)) == :lp - g.layer.op.p == 2 || print(io, ", p=", g.layer.op.p) - end + $(Meta.quot(op)) == :lp && (op.p == 2 || print(io, ", p=", op.p)) print(io, ")") end @@ -248,16 +247,17 @@ for layer_op in (:Max, :Mean, :LP) layer <: PoolingLayer end + Experimental.layer_map_leaf(::KeyPath, ::$(adaptive_layer_name)) = true + function $(adaptive_layer_name)(out_size::Tuple{Vararg{IntegerType}}; p=2) return $(adaptive_layer_name)(PoolingLayer( static(:adaptive), $(Meta.quot(op)), out_size; p)) end function Base.show(io::IO, a::$(adaptive_layer_name)) - print(io, string($(Meta.quot(adaptive_layer_name))), "(", a.layer.mode.out_size) - if $(Meta.quot(op)) == :lp - a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p) - end + (; mode, op) = a.layer + print(io, string($(Meta.quot(adaptive_layer_name))), "(", mode.out_size) + $(Meta.quot(op)) == :lp && (op.p == 2 || print(io, ", p=", op.p)) print(io, ")") end diff --git a/test/Project.toml b/test/Project.toml index aca27bdbf3..429ba85c5a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -77,7 +77,7 @@ SimpleChains = "0.4.7" StableRNGs = "1.0.2" Static = "1" StaticArrays = "1.9" -Statistics = "1.11.1" +Statistics = "1.10" Test = "1.10" Tracker = "0.2.36" Zygote = "0.6.70" diff --git a/test/contrib/debug_tests.jl b/test/contrib/debug_tests.jl index 6de0618743..a8b0109450 100644 --- a/test/contrib/debug_tests.jl +++ b/test/contrib/debug_tests.jl @@ -151,3 +151,20 @@ end @test !any(isnan, gs.layer_3.bias) end end + +@testitem "Debugging Tools: Issue #1068" setup=[SharedTestSetup] tags=[:misc] begin + model = Chain( + Conv((3, 3), 3 => 16, relu; stride=2), + MaxPool((2, 2)), + AdaptiveMaxPool((2, 2)), + GlobalMaxPool() + ) + + model_debug = Lux.Experimental.@debug_mode model + display(model_debug) + + @test model_debug[1] isa Lux.Experimental.DebugLayer + @test model_debug[2] isa Lux.Experimental.DebugLayer + @test model_debug[3] isa Lux.Experimental.DebugLayer + @test model_debug[4] isa Lux.Experimental.DebugLayer +end