
Stack overflow from recursive function in DynamicExpressions.jl #428

Open
MilesCranmer opened this issue Dec 19, 2024 · 4 comments
Labels
enhancement (error messages) The error was produced that should be improved upon


@MilesCranmer

Hey all,

Thanks for working on this! I'm trying out Mooncake 0.4.65 on DynamicExpressions.jl 1.8.0 (via DifferentiationInterface 0.6.27) and ran into a stack overflow from this example:

import Mooncake
using DynamicExpressions
using DifferentiationInterface

# Build up expression:
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin))
variable_names = ["x1", "x2", "x3"]
x1, x2, x3 = map(i -> Expression(Node{Float64}(; feature=i); operators, variable_names), 1:3)
f = x1 + cos(x2 - 0.2)

eval_sum = let f = f
    X -> sum(f(X)[1])
end
backend = AutoMooncake(; config=nothing)

# Example data
X = randn(3, 100)
dX = gradient(eval_sum, backend, X)

This hits the following error:

ERROR: LoadError: MooncakeRuleCompilationError: an error occured while Mooncake was compiling a rule to differentiate something. If the `caused by` error message below does not make it clear to you how the problem can be fixed, please open an issue at github.com/compintell/Mooncake.jl describing your problem.
To replicate this error run the following:

Mooncake.build_rrule(Mooncake.MooncakeInterpreter(), Tuple{Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, Matrix{Float64}}; debug_mode=false)

Note that you may need to `using` some additional packages if not all of the names printed in the above signature are available currently in your environment.

Stacktrace:
 [1] build_rrule(interp::Mooncake.MooncakeInterpreter{Mooncake.DefaultCtx}, sig_or_mi::Type; debug_mode::Bool, silence_debug_messages::Bool)
   @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1074
 [2] build_rrule
   @ ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1017 [inlined]
 [3] prepare_pullback_cache(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::Vararg{Any}; kwargs::@Kwargs{debug_mode::Bool, silence_debug_messages::Bool})
   @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interface.jl:191
 [4] prepare_pullback_cache
   @ ~/.julia/packages/Mooncake/N9iX9/src/interface.jl:185 [inlined]
 [5] prepare_pullback(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::AutoMooncake{Nothing}, ::Matrix{Float64}, ::Tuple{Bool})
   @ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/gjT8p/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:10
 [6] prepare_gradient(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::AutoMooncake{Nothing}, ::Matrix{Float64})
   @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/gjT8p/src/first_order/gradient.jl:70
 [7] gradient(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::AutoMooncake{Nothing}, ::Matrix{Float64})
   @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/gjT8p/src/fallbacks/no_prep.jl:48
 [8] top-level scope
   @ ~/PermaDocuments/SymbolicRegressionMonorepo/DynamicExpressions.jl/test/test_mooncake.jl:29
in expression starting at /Users/mcranmer/PermaDocuments/SymbolicRegressionMonorepo/DynamicExpressions.jl/test/test_mooncake.jl:29

caused by: StackOverflowError:
Stacktrace:
     [1] _stable_typeof
       @ ./operators.jl:929 [inlined]
     [2] Base.Fix1(f::typeof(Mooncake.tangent_field_type), x::Type)
       @ Base ./operators.jl:1123
     [3] tangent_field_types
       @ ~/.julia/packages/Mooncake/N9iX9/src/tangents.jl:445 [inlined]
     [4] #s11#44
       @ ~/.julia/packages/Mooncake/N9iX9/src/tangents.jl:437 [inlined]
     [5] var"#s11#44"(P::Any, ::Any, ::Any)
       @ Mooncake ./none:0
     [6] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
       @ Core ./boot.jl:707
     [7] tangent_field_type(::Type{Node{Float64}}, n::Int64)
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/tangents.jl:461
     [8] Fix1
       @ ./operators.jl:1127 [inlined]
     [9] tuple_map(f::Base.Fix1{typeof(Mooncake.tangent_field_type), Type{Node{Float64}}}, x::NTuple{7, Int64})
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/utils.jl:41
--- the above 7 lines are repeated 5435 more times ---
 [38055] tangent_field_types
       @ ~/.julia/packages/Mooncake/N9iX9/src/tangents.jl:445 [inlined]
 [38056] #s11#44
       @ ~/.julia/packages/Mooncake/N9iX9/src/tangents.jl:437 [inlined]
 [38057] var"#s11#44"(P::Any, ::Any, ::Any)
       @ Mooncake ./none:0
 [38058] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
       @ Core ./boot.jl:707
 [38059] #s11#110
       @ ~/.julia/packages/Mooncake/N9iX9/src/fwds_rvs_data.jl:560 [inlined]
 [38060] var"#s11#110"(P::Any, ::Any, ::Any)
       @ Mooncake ./none:0
 [38061] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
       @ Core ./boot.jl:707
 [38062] lazy_zero_rdata_type(::Type{Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}})
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/fwds_rvs_data.jl:751
 [38063] call_composed
       @ ./operators.jl:1053 [inlined]
 [38064] (::ComposedFunction{typeof(Mooncake.lazy_zero_rdata_type), typeof(Mooncake._type)})(x::Type; kw::@Kwargs{})
       @ Base ./operators.jl:1050
 [38065] iterate
       @ ./generator.jl:48 [inlined]
 [38066] _collect(c::Vector{Any}, itr::Base.Generator{Vector{Any}, ComposedFunction{typeof(Mooncake.lazy_zero_rdata_type), typeof(Mooncake._type)}}, ::Base.EltypeUnknown, isz::Base.HasShape{1})
       @ Base ./array.jl:811
 [38067] collect_similar
       @ ./array.jl:720 [inlined]
 [38068] map
       @ ./abstractarray.jl:3371 [inlined]
 [38069] Mooncake.ADInfo(interp::Mooncake.MooncakeInterpreter{Mooncake.DefaultCtx}, ir::Mooncake.BBCode, debug_mode::Bool)
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:175
 [38070] generate_ir(interp::Mooncake.MooncakeInterpreter{Mooncake.DefaultCtx}, sig_or_mi::Type; debug_mode::Bool, do_inline::Bool)
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1104
 [38071] generate_ir
       @ ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1087 [inlined]
 [38072] build_rrule(interp::Mooncake.MooncakeInterpreter{Mooncake.DefaultCtx}, sig_or_mi::Type; debug_mode::Bool, silence_debug_messages::Bool)
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1050
 [38073] build_rrule
       @ ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1017 [inlined]
 [38074] prepare_pullback_cache(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::Vararg{Any}; kwargs::@Kwargs{debug_mode::Bool, silence_debug_messages::Bool})
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interface.jl:191
 [38075] prepare_pullback_cache
       @ ~/.julia/packages/Mooncake/N9iX9/src/interface.jl:185 [inlined]
 [38076] prepare_pullback(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::AutoMooncake{Nothing}, ::Matrix{Float64}, ::Tuple{Bool})
       @ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/gjT8p/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:10
 [38077] prepare_gradient(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::AutoMooncake{Nothing}, ::Matrix{Float64})
       @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/gjT8p/src/first_order/gradient.jl:70

I'm assuming this comes from the recursive evaluation in DynamicExpressions.jl, which branches from here: https://github.com/SymbolicML/DynamicExpressions.jl/blob/dde92915df3ed275989e53d3691fd7f9280d9b14/src/Evaluate.jl#L242-L269. I think it should be possible to make this work, since Enzyme.jl can now differentiate it. Zygote.jl can't, however, because the evaluation mutates arrays.

@willtebbutt
Member

willtebbutt commented Dec 19, 2024

Hi Miles. Thanks for trying out Mooncake!

The only potential source of stack overflows in Mooncake that I'm currently aware of is asking for the tangent_type of a type whose name appears in its own definition. For example, something like

julia> using Mooncake

julia> struct Foo
           x::Union{Foo, Nothing}
       end

julia> tangent_type(Foo)
ERROR: StackOverflowError:
Stacktrace:
 [1] tangent_type(::Type{Foo})
   @ Mooncake ./none:0
 [2] macro expansion
   @ ./none:0 [inlined]
 [3] tangent_type(::Type{Union{Nothing, Foo}})
   @ Mooncake ./none:0

It looks to me like your stack overflow is happening during a tangent_type call so, before I dig into your issue a bit further, could you confirm if e.g. your Expression or Node types have this property?

@MilesCranmer
Author

Thanks! Yes the Node is recursive: https://ai.damtp.cam.ac.uk/dynamicexpressions/dev/api/#Nodes. It’s a binary tree structure. Any workarounds?

@willtebbutt
Member

Ah, damn. Sadly there is no workaround at the minute.

Short explanation: Mooncake derives a "tangent type" for each Julia type it encounters, and it does this recursively. For a struct's "primal" type, it produces something of the form Tangent{NamedTuple{fieldnames, tangent_types_of_fields}}, where tangent_types_of_fields is a Tuple containing the result of tangent_type for each of the fields of the original struct. This is where the problem arises: if the type's name appears in the types of its own fields, computing its tangent type requires computing that same tangent type again, so the recursion never terminates.
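To make the recursion concrete, here is a toy model of deriving a tangent type field by field. It is a sketch under stated assumptions, not Mooncake's implementation: `toy_tangent_type`, `ToyTangent` and `SelfReferentialTypeError` are hypothetical names, and only `Float64`, `Nothing`, integers, `Union`s and plain structs are handled. It tracks the types currently being derived so that a self-referential struct produces an informative error instead of a stack overflow.

```julia
# Toy model (hypothetical, not Mooncake's code) of recursive tangent-type derivation.
struct ToyTangent{T<:NamedTuple}
    fields::T
end

# Hypothetical error type: raised instead of overflowing the stack.
struct SelfReferentialTypeError <: Exception
    T::Type
end
Base.showerror(io::IO, e::SelfReferentialTypeError) = print(
    io, "type ", e.T, " refers to itself in its fields; ",
    "its tangent type cannot be derived recursively",
)

function toy_tangent_type(P::Type, seen::Vector{Any}=Any[])
    P === Float64 && return Float64
    P === Nothing && return Nothing
    P <: Integer && return Nothing  # integers are treated as non-differentiable here
    if P isa Union
        # Tangent of a Union is the Union of the tangents of its members.
        return Union{map(U -> toy_tangent_type(U, seen), Base.uniontypes(P))...}
    end
    # Without this check, a struct whose fields mention the struct itself
    # would recurse forever (the stack overflow reported in this issue).
    any(S -> S === P, seen) && throw(SelfReferentialTypeError(P))
    push!(seen, P)
    names = fieldnames(P)
    tangents = map(T -> toy_tangent_type(T, seen), fieldtypes(P))
    pop!(seen)
    # Mirrors the Tangent{NamedTuple{fieldnames, tangent_types_of_fields}} shape.
    return ToyTangent{NamedTuple{names, Tuple{tangents...}}}
end
```

A struct with two `Float64` fields maps to `ToyTangent{NamedTuple{(:a, :b), Tuple{Float64, Float64}}}`, while a struct like `Foo` above raises `SelfReferentialTypeError` as soon as the derivation reaches `Foo` a second time.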

Enzyme circumvents this problem entirely by using the primal type as its own tangent type.
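A toy illustration of the "primal type is its own tangent type" idea follows. It is not Enzyme's code or API; `ToyNode`, `toy_zero`, `toy_sum` and `toy_sum_pullback!` are hypothetical. The shadow is simply another `ToyNode` with the same shape, and gradients are accumulated into its `val` fields. Recursion here walks over a finite tree *value*, so it terminates, whereas a type-level derivation over a self-referential *type* has no base case.

```julia
# Hypothetical binary-tree node; its fields refer back to ToyNode itself.
mutable struct ToyNode
    val::Float64
    left::Union{ToyNode, Nothing}
    right::Union{ToyNode, Nothing}
end

# Zero "shadow" with exactly the same type and shape as the primal tree.
toy_zero(::Nothing) = nothing
toy_zero(n::ToyNode) = ToyNode(0.0, toy_zero(n.left), toy_zero(n.right))

# Primal function: sum of all node values.
toy_sum(::Nothing) = 0.0
toy_sum(n::ToyNode) = n.val + toy_sum(n.left) + toy_sum(n.right)

# Reverse pass for toy_sum: d(toy_sum)/d(val) = 1 at every node,
# so dy is accumulated into the matching shadow node.
toy_sum_pullback!(::Nothing, ::Nothing, dy) = nothing
function toy_sum_pullback!(shadow::ToyNode, n::ToyNode, dy::Float64)
    shadow.val += dy
    toy_sum_pullback!(shadow.left, n.left, dy)
    toy_sum_pullback!(shadow.right, n.right, dy)
    return nothing
end
```

For a three-node tree, `toy_zero(tree)` has type `ToyNode`, and after `toy_sum_pullback!(shadow, tree, 1.0)` every shadow `val` equals `1.0`.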

In the short term I can probably improve the error message to make it so that future users do not have to open an issue about this. In the medium term we should be able to add a macro that makes it a one-line fix to make this work correctly.
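As a rough sketch of what a one-line opt-in could look like, here is a toy macro. It is purely hypothetical: `@primal_is_tangent` and `toy_tangent_type` are illustrative names, not part of Mooncake. The macro adds a more specific method declaring that a type is its own tangent type, so the derivation never has to recurse into that type's fields.

```julia
# Stand-in for a recursive derivation; in this toy it only knows Float64.
toy_tangent_type(::Type{Float64}) = Float64
toy_tangent_type(P::Type) = error("toy: no tangent type rule for $P")

# Hypothetical one-line opt-in: define a method saying T is its own tangent type.
# Dispatch picks ::Type{T} over the generic P::Type fallback.
macro primal_is_tangent(T)
    return esc(:(toy_tangent_type(::Type{$T}) = $T))
end

# A self-referential tree type, similar in spirit to a recursive expression node.
mutable struct ToyTree
    val::Float64
    children::Vector{ToyTree}
end

@primal_is_tangent ToyTree
```

After the macro call, `toy_tangent_type(ToyTree)` returns `ToyTree` directly instead of falling through to the generic method.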

I'm going to label this issue as a "should have given a better error" issue for now.

@willtebbutt willtebbutt added the enhancement (error messages) The error was produced that should be improved upon label Dec 19, 2024
@yebai
Contributor

yebai commented Dec 19, 2024

> In the medium term we should be able to add a macro that makes it a one-line fix to make this work correctly.

@willtebbutt let's help implement this (or a simpler version if it involves a lot of work), so that @MilesCranmer can run Mooncake with DynamicExpressions.jl.
