Skip to content

Commit

Permalink
Some small printing upgrades (#2344)
Browse files Browse the repository at this point in the history
* some printing upgrades

* print eltype too

* move one line to solve order-of-loading issue

* better fix

* tests, and Fix1
  • Loading branch information
mcabbott authored Oct 26, 2024
1 parent c9bab66 commit 93e1de7
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.14.22"
version = "0.14.23"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
73 changes: 68 additions & 5 deletions src/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ function _macro_big_show(ex)
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")")
pre, post = _show_pre_post(obj)
children = _show_children(obj)
if all(_show_leaflike, children)
# This check may not be useful anymore: it tries to infer when to stop the recursion by looking for grandkids,
# but once all layers use @layer, they stop the recursion by defining a method for _big_show.
_layer_show(io, obj, indent, name)
else
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre)
if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
# then we insert names -- can this be done more generically?
println(io, " "^indent, isnothing(name) ? "" : "$name = ", pre)
if obj isa Chain{<:NamedTuple} || obj isa NamedTuple
# then we insert names -- can this be done more generically?
for k in Base.keys(obj)
_big_show(io, obj[k], indent+2, k)
end
Expand All @@ -52,6 +52,20 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
end
end

for Fix in (:Fix1, :Fix2)
pre = string(Fix, "(")
@eval function _big_show(io::IO, obj::Base.$Fix, indent::Int=0, name=nothing)
println(io, " "^indent, isnothing(name) ? "" : "$name = ", $pre)
_big_show(io, obj.f, indent+2)
_big_show(io, obj.x, indent+2)
println(io, " "^indent, ")", ",")
end
end

_show_pre_post(obj) = string(nameof(typeof(obj)), "("), ")"
_show_pre_post(::AbstractVector) = "[", "]"
_show_pre_post(::NamedTuple) = "(;", ")"

_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for:

# note the covariance of tuple, using <:T causes warning or error
Expand Down Expand Up @@ -88,7 +102,7 @@ end

function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
_str = isnothing(name) ? "" : "$name = "
str = _str * sprint(show, layer, context=io)
str = _str * _layer_string(io, layer)
print(io, " "^indent, str, indent==0 ? "" : ",")
if !isempty(params(layer))
print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
Expand All @@ -103,6 +117,15 @@ color=:light_black)
indent==0 || println(io)
end

_layer_string(io::IO, layer) = sprint(show, layer, context=io)
# _layer_string(::IO, a::AbstractArray) = summary(layer) # sometimes too long e.g. CuArray
function _layer_string(::IO, a::AbstractArray)
full = string(typeof(a))
comma = findfirst(',', full)
short = isnothing(comma) ? full : full[1:comma] * "...}"
Base.dims2string(size(a)) * " " * short
end

function _big_finale(io::IO, m)
ps = params(m)
if length(ps) > 2
Expand Down Expand Up @@ -150,3 +173,43 @@ _any(f, x::Number) = f(x)
# _any(f, x) = false

_all(f, xs) = !_any(!f, xs)

#=
julia> struct Tmp2; x; y; end; Flux.@functor Tmp2
# Before, notice Array(), NamedTuple(), and values
julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
Chain(
Tmp2(
Array(
Dense(2 => 3), # 9 parameters
[0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters
),
NamedTuple(
1:3, # 3 parameters
Dense(3 => 4), # 16 parameters
[0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters
),
),
) # Total: 7 arrays, 43 parameters, 644 bytes.
# After, (; x=, y=, z=) and "3-element Array"
julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
Chain(
Tmp2(
[
Dense(2 => 3), # 9 parameters
4×3 Adjoint, # 12 parameters
],
(;
x = 3-element UnitRange, # 3 parameters
y = Dense(3 => 4), # 16 parameters
z = 3-element Array, # 3 parameters
),
),
) # Total: 7 arrays, 43 parameters, 644 bytes.
=#
8 changes: 7 additions & 1 deletion test/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,13 @@ end
# [email protected] marks transposed matrices non-leaf, shouldn't affect printing:
adjoint_chain = repr("text/plain", Chain([Dense([1 2; 3 4]')]))
@test occursin("Dense(2 => 2)", adjoint_chain)
@test occursin("Chain([", adjoint_chain)
@test occursin("Chain(", adjoint_chain)
@test occursin("[", adjoint_chain)

# New printing of arrays, and Fix1
fix_chain = repr("text/plain", Chain(Base.Fix1(*, rand32(22,33)), softmax))
@test occursin("Fix1(", fix_chain)
@test occursin("22×33 Matrix{Float32}", fix_chain)
end

# Bug when no children, https://github.com/FluxML/Flux.jl/issues/2208
Expand Down

0 comments on commit 93e1de7

Please sign in to comment.